Semantic Segmentation 모델 중 FastFCN이란 모델을 학습 한 후 onnx 파일로 export 하던 중 오류가 발생하였습니다.
Traceback (most recent call last):
File "show.py", line 123, in <module>
evaluatev0(dspth='/home/im/STDC-Seg/data', scale=1.0)
File "show.py", line 115, in evaluatev0
torch.onnx.export(net, dummy_input, "fastfcn.onnx", verbose=True, input_names=["input"], output_names=["output"], export_params=True, opset_version=11)
File "/home/im/anaconda3/envs/fastfcn/lib/python3.6/site-packages/torch/onnx/__init__.py", line 208, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/im/anaconda3/envs/fastfcn/lib/python3.6/site-packages/torch/onnx/utils.py", line 92, in export
use_external_data_format=use_external_data_format)
File "/home/im/anaconda3/envs/fastfcn/lib/python3.6/site-packages/torch/onnx/utils.py", line 545, in _export
val_add_node_names, val_use_external_data_format, model_file_location)
RuntimeError: ONNX export failed: Couldn't export Python operator ScaledL2
찾아본 결과 Python native 함수만을 사용해야 한다고 하는데, class를 어떻게 바꾸어 줘야 할 지 방법을 전혀 모르겠습니다.
해당 부분은 다음과 같습니다.
A = F.softmax(scaled_l2(X, self.codewords, self.scale), dim=2)
이곳에서 사용한 scaled_l2은 다음과 같습니다.
class ScaledL2(Function):
@staticmethod
def forward(ctx, X, C, S):
SL = (X.unsqueeze(2).expand(X.size(0), X.size(1), C.size(0), C.size(1)) -
C.unsqueeze(0).unsqueeze(0)).pow_(2).sum(3).mul_(S.view(1, 1, C.size(0)))
ctx.save_for_backward(X, C, S, SL)
print("forward")
print("ctx:",ctx)
print("X:", X.shape)
print("C:", C.shape)
print("S:", S.shape)
print("SL:", SL.shape)
return SL
@staticmethod
def backward(ctx, GSL):
X, C, S, SL = ctx.saved_variables
tmp = (X.unsqueeze(2).expand(X.size(0), X.size(1), C.size(0), C.size(1)) - C.unsqueeze(0).unsqueeze(0)).mul_(
(2 * GSL).mul_(S.view(1, 1, C.size(0))).unsqueeze(3)
)
GX = tmp.sum(2)
GC = tmp.sum((0, 1)).mul_(-1)
GS = SL.div(S.view(1, 1, C.size(0))).mul_(GSL).sum((0, 1))
print("backward")
print("GX:", GX.shape)
print("GC:", GC.shape)
print("GS:", GS.shape)
return GX, GC, GS
def scaled_l2(X, C, S):
return ScaledL2.apply(X, C, S)
사용하고 있는 버전은 다음과 같습니다.
torch: 1.6.0
CUDA: 10.1
CUDNN: 7.6.5
tensorrt: 5.1.5.0