행렬 크기가 클수록 속도 차이가 많이납니다만, 너무 크게하면 오류가 뜹니다.
아직은 pytorch 베타버전이란 그런것 같습니다.
CPU와 MPS 차이가 단순연산에선 최대 100배 정도 나는것 같습니다.
위의 소스도 함께 올립니다.
import sys
import torch
import platform
import time
print("OS version : ",platform.platform())
print("Python version : ",sys.version)
print("Torch Version : ",torch._ version _)
print("")
print("MPS Bulit : ",torch.backends.mps.is_built())
print("MPS avail : ",torch.backends.mps.is_available())
print("")
a=torch.rand(10000,5000)
b=torch.rand(5000,10000)
tic=time.time()
torch.matmul(a,b)
toc=time.time()
print(a.device," : ",toc-tic)
c=torch.rand(10000,5000).to("mps")
d=torch.rand(5000,10000).to("mps")
tic=time.time()
torch.matmul(c,d)
toc=time.time()
print(c.device," : ",toc-tic)