MPS backend적용시 에러발생 문의

안녕하세요 최근에 M1칩의 gpu를 이용한 파이토치가 릴리즈 된 기쁜소식을 듣고,
업데이트를 한 뒤(1.12.) 기존의 cuda로 돌리던 모델을 한번 mps로 바꾸어서 돌렸는데 다음과 같은 에러 메세지가 발생하였습니다.

/AppleInternal/Library/BuildRoots/20d6c351-ee94-11ec-bcaf-7247572f23b4/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/Operations/MPSGraphConvolutionOps.mm:346: failed assertion `sourceTensor rank must be 4’

구글에 검색을 해보니 아직까지 mps가 conv3D를 지원하지 않아서 위와 같은 오류가 발생한다는 글(torch.nn.Conv3D on MPS backend · Issue #77818 · pytorch/pytorch · GitHub)을 하나 발견하였는데,
혹시 해결방법이 있을까요?

감사합니다.

저도 비슷한 문제(다른 op) 때문에 간단한 모델 mps backend 써보려다 포기했는데요…
아직 mps backend 기반 op 들이 많이 만들어지지 않아, 실질적으로 사용이 어려워보입니다 ㅠㅠ

이 글은 마지막 댓글이 달리고 30일 뒤 자동적으로 닫혔습니다. 새 댓글을 다실 수 없습니다.