Apple M1/M2이 탑재된 장치에서 GPU 가속을 사용하려면 어떻게 해야 하나요?

알아두기

  • 이 기능은 Apple M1 칩이 탑재된 기기에서만 사용이 가능합니다.
  • Apple M1 칩에서의 PyTorch GPU 가속 기능은 아직 정식 릴리즈가 되지 않았습니다. (2022년 5월 20일 현재)
  • 따라서 최신 기능이 포함된 Preview(Nightly) 버전을 사용하셔야 하며, 이 기능은 불안정할 수 있습니다.
  • 가급적 pyenv나 conda 등을 사용하여 별도의 가상 환경에서 테스트 용도로만 사용하시기를 권해드립니다.

답변

Preview 버전 설치하기

  • 직접 설치하기 페이지의 설명을 따라 Preview 버전을 설치합니다.
  • 또는, pip 사용자의 경우 다음 명령어로 설치할 수 있습니다.
    pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
    
  • 또는, conda 사용자의 경우 다음 명령어로 설치할 수 있어야 합니다.
    conda install pytorch torchvision torchaudio -c pytorch-nightly
    
  • :raised_hand_with_fingers_splayed: 현재 사용하는 conda가 ARM64 아키텍처를 지원하지 않을 수 있습니다.
    아래 코드를 실행하여 ARM64를 지원하는지 반드시 확인해주세요!
    python -c 'import platform;print(platform.platform())'
    
    위 코드의 실행 결과가 macos-12.4-arm64-arm-64bit 등과 같이 arm64를 반드시 포함하고 있어야 합니다.
    만약 ARM64를 지원하지 않는다면, conda를 재설치하셔야 합니다.

현재 설치된 PyTorch 버전이 MPS Backend를 지원하는지 확인하기

  • 현재 설치된 PyTorch의 버전이 1.12 또는 그 이상인지와,
    PyTorch가 M1 칩을 사용하기 위한 MPS 장치를 사용할 수 있도록 빌드되어 있는지를 확인해야 합니다.
  • 아래 명령어로 PyTorch 버전과 MPS 장치 사용이 가능하도록 빌드되었는지를 확인하실 수 있습니다.
    >>> import torch
    >>> print(torch.__version__) # 설치된 PyTorch 버전을 확인합니다. 1.12 이상이어야 합니다.
    1.12.0.dev20220519
    >>> print(torch.backends.mps.is_built()) # MPS 장치를 지원하도록 빌드되어있는지 확인합니다. True여야 합니다.
    True
    

MPS 장치가 사용 가능한지 확인하기

  • 현재 설치된 장치에 사용 가능한 MPS 장치가 있는지 확인합니다.
  • NVIDIA GPU의 경우 CUDA에서 사용 가능한 장치가 있는지를 torch.cuda.is_available()로 확인하는 것과 비슷합니다.
    >>> import torch
    >>> print(torch.backends.mps.is_available()) # MPS 장치가 사용 가능한지 확인합니다. True여야 합니다.
    True
    

MPS 장치를 사용하도록 지정하기

  • 기존의 NVIDIA GPU에서 cuda를 사용했던 것처럼, mps 라는 이름으로 장치를 사용할 수 있습니다.
  • PyTorch,org 문서에서 제공하는 예시 코드는 아래와 같습니다.
    import torch
    
    mps_device = torch.device("mps")
    
    # MPS 장치에 바로 tensor를 생성합니다.
    x = torch.ones(5, device=mps_device)
    # 또는
    x = torch.ones(5, device="mps")
    
    # GPU 상에서 연산을 진행합니다.
    y = x * 2
    
    # 또는, 다른 장치와 마찬가지로 MPS로 이동할 수도 있습니다.
    model = YourFavoriteNet()  # 어떤 모델의 객체를 생성한 뒤,
    model.to(mps_device)       # MPS 장치로 이동합니다.
    
    # 이제 모델과 텐서를 호출하면 GPU에서 연산이 이뤄집니다.
    pred = model(x)
    

더 알아보기

4개의 좋아요