torch모델을 tflite로 변환시 input shape 문제

안녕하세요,

모델 변환 과정에서 input shape과 관련된 궁금증이 생겨 글 남깁니다.

모델 변환 과정에 사용한 순서는 다음과 같습니다.
torch(N,C,H,W)->onnx(N,C,H,W)->tflite(N,H,W,C)

해당 모델에서 inference 결과값은 정상적으로 나오는 것을 확인 했습니다.

하지만, tflite는 input의 shape이 (N,H,W,C)을 지원하기 때문에

onnx->tflite과정에서 생성된 tflite의 모델 구조를 확인해 보면, 모든 convolution layer마다 transpose가 추가 되어 차원 순서를 변환하는 연산을 통해 모델의 inference 속도가 저하되는 문제가 발생합니다.

torch로 학습된 모델을 사용한다고 가정하였을때, 이러한 converter들을 사용하는 과정에서 input의 shape을 변경하는 방법이 궁금합니다.

좋아요 1

안녕하세요, @EunJae_Ha 님.

제가 겪어보지 못한 이슈라 조심스럽지만 한 번 아래와 같은 memory_format 파라매터를 사용해보시는건 어떠실까요?
https://tutorials.pytorch.kr/intermediate/memory_format_tutorial.html

아니면 keras나 tf 등으로 동일한 모델을 만드신 다음에 PyTorch의 각 Layer별 weight를 복사하시는 것도 방법이실 것 같습니다. ^^;