ONNX로 변환된 Node와 원본 PyTorch module 매핑

혹시 PyTorch 모델을 onnx로 export했을 때 각 PyTorch module이 어떤 onnx Node로 변경 (또는 매핑)되었는 지 알 수 있거나 관련 경험이 있으신 분이 계실까요? 그래프 생성과정에서 onnx optimization (layer fusion 등)이 들어가다보니 코드 레벨로 파악하기가 어려운 것 같은데요

예를 들어 PyTorch module의 레이어가 self.conv, self.relu, self.bn 세 개로 구성되어있고
이것을 torch.onnx.export로 내보낸 ONNX의 Node 정보가 있을텐데
기존 PyTorch module과 Node 가 어떻게 매핑되었는 지 알 수 있는 방법이 있을까요?

  • 아래 답변은 OpenAI의 ChatGPT로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

ONNX 모델이 생성될 때 PyTorch의 모듈과 ONNX의 노드 간의 직접적인 매핑은 존재하지 않습니다. 따라서 ONNX 모델에서 각각의 노드가 PyTorch의 어떤 모듈에 해당하는 것인지 알아내기 위해서는 ONNX 모델의 구조를 이해해야 합니다.

ONNX 모델은 노드(Node)와 엣지(Edge)로 이루어진 그래프(Graph) 구조로 이루어져 있습니다. 각 노드는 연산자(Ops)를 나타내며, 입력(input)과 출력(output)을 받습니다. 따라서 PyTorch 모듈과 ONNX 노드 사이의 매핑은 ONNX 모델의 그래프 구조를 분석하여 유추할 수 있습니다.

예를들어, PyTorch 모델에서 self.conv, self.relu, self.bn이 연속해서 적용된 경우, 이를 ONNX 모델로 변환하면 Conv+BatchNorm+Relu 노드로 변환됩니다. 이때 Conv노드에 들어가는 weight와 bias는 각각 PyTorch 모듈의 self.conv.weight과 self.conv.bias입니다. ONNX 모델의 입력 순서는 PyTorch 모델의 순서와 다른 경우도 있기 때문에 이에 대한 대응도 필요합니다. 이러한 변환 과정은 onnx를 생성할 때 자동으로 수행됩니다.

ONNX 모델의 그래프 구조 및 각 노드들이 어떤 입력을 받고, 어떤 출력을 내보내는지등의 정보는 ONNX 모델의 메타데이터(metadata)에 저장되어 있습니다. 이 메타데이터는 onnx.helper.printable_graph 함수를 사용하여 볼 수 있습니다.

하지만, 이 전체 프로세스는 매우 복잡하며 구체적인 것들은 다양한 옵션과 가정에 따라 변경되기 때문에, 일반적인 규칙은 없습니다.따라서, Node와 PyTorch module 사이의 매핑을 알고 싶다면, 모델의 구조와 ONNX의 연산자, 입력 및 출력에 대한 이해와 함께, 각 모둘에 대한 ONNX 노드를 하나씩 찾아가는 과정이 필요합니다.이번에 참고하실 가이드가 도움이 되실 수 있습니다. link


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주세요!