안녕하세요 최근에 M1칩의 gpu를 이용한 파이토치가 릴리즈 된 기쁜소식을 듣고,
업데이트를 한 뒤(1.12.) 기존의 cuda로 돌리던 모델을 한번 mps로 바꾸어서 돌렸는데 다음과 같은 에러 메세지가 발생하였습니다.
/AppleInternal/Library/BuildRoots/20d6c351-ee94-11ec-bcaf-7247572f23b4/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/Operations/MPSGraphConvolutionOps.mm:346: failed assertion `sourceTensor rank must be 4’
구글에 검색을 해보니 아직까지 mps가 conv3D를 지원하지 않아서 위와 같은 오류가 발생한다는 글(torch.nn.Conv3D on MPS backend · Issue #77818 · pytorch/pytorch · GitHub)을 하나 발견하였는데,
혹시 해결방법이 있을까요?
감사합니다.