torch.onnx.export 중 class 함수의 오류

Semantic Segmentation 모델 중 FastFCN이란 모델을 학습 한 후 onnx 파일로 export 하던 중 오류가 발생하였습니다.

Traceback (most recent call last):
  File "show.py", line 123, in <module>
    evaluatev0(dspth='/home/im/STDC-Seg/data', scale=1.0)
  File "show.py", line 115, in evaluatev0
    torch.onnx.export(net, dummy_input, "fastfcn.onnx", verbose=True, input_names=["input"], output_names=["output"], export_params=True, opset_version=11)
  File "/home/im/anaconda3/envs/fastfcn/lib/python3.6/site-packages/torch/onnx/__init__.py", line 208, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/im/anaconda3/envs/fastfcn/lib/python3.6/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/home/im/anaconda3/envs/fastfcn/lib/python3.6/site-packages/torch/onnx/utils.py", line 545, in _export
    val_add_node_names, val_use_external_data_format, model_file_location)
RuntimeError: ONNX export failed: Couldn't export Python operator ScaledL2

찾아본 결과 Python native 함수만을 사용해야 한다고 하는데, class를 어떻게 바꾸어 줘야 할 지 방법을 전혀 모르겠습니다.
해당 부분은 다음과 같습니다.
A = F.softmax(scaled_l2(X, self.codewords, self.scale), dim=2)
이곳에서 사용한 scaled_l2은 다음과 같습니다.

class ScaledL2(Function):
    @staticmethod
    def forward(ctx, X, C, S):
        SL = (X.unsqueeze(2).expand(X.size(0), X.size(1), C.size(0), C.size(1)) -
              C.unsqueeze(0).unsqueeze(0)).pow_(2).sum(3).mul_(S.view(1, 1, C.size(0)))
        ctx.save_for_backward(X, C, S, SL)
        print("forward")
        print("ctx:",ctx)
        print("X:", X.shape)
        print("C:", C.shape)
        print("S:", S.shape)
        print("SL:", SL.shape)
        return SL

    @staticmethod
    def backward(ctx, GSL):
        X, C, S, SL = ctx.saved_variables

        tmp = (X.unsqueeze(2).expand(X.size(0), X.size(1), C.size(0), C.size(1)) - C.unsqueeze(0).unsqueeze(0)).mul_(
            (2 * GSL).mul_(S.view(1, 1, C.size(0))).unsqueeze(3)
        )

        GX = tmp.sum(2)
        GC = tmp.sum((0, 1)).mul_(-1)
        GS = SL.div(S.view(1, 1, C.size(0))).mul_(GSL).sum((0, 1))
        print("backward")
        print("GX:", GX.shape)
        print("GC:", GC.shape)
        print("GS:", GS.shape)
        return GX, GC, GS

def scaled_l2(X, C, S):
    return ScaledL2.apply(X, C, S)

사용하고 있는 버전은 다음과 같습니다.
torch: 1.6.0
CUDA: 10.1
CUDNN: 7.6.5
tensorrt: 5.1.5.0

좋아요 1