파이토치로 숫자를 원-핫(One-hot) 형식으로 어떻게 만드나요?

공식 홈페이지StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다.

다음 링크에서 원문을 함께 찾아보실 수 있습니다.


질문

  • mnist와 같은 분류 문제에서, 정답 레이블(label)을 정수(int) 값으로 가지고 있습니다.
  • 이 정수 값을 어떻게 원-핫 벡터(One-hot vector) 형식으로 바꿀 수 있나요?

답변

  • torch.nn.functional.one_hot() 함수를 사용하여 원-핫 형식으로 바꿀 수 있습니다.
  • 아래와 같은 정답 레이블(label)이 있다고 가정합니다.
    >>> import torch
    >>> torch.arange(0, 5) % 3
    tensor([0, 1, 2, 0, 1])
    
  • 이 때. torch.nn.functional.one_hot()의 동작은 다음과 같습니다.
    >>> import torch.nn.functional as F
    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
    tensor([[1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0]])
    
  • num_classes는 정답 클래스의 총 개수로, 없는 경우 주어진 텐서의 가장 큰 값보다 1만큼 큰 수로 설정됩니다.
    >>> import torch.nn.functional as F
    >>> F.one_hot(torch.arange(0, 5) % 3)
    tensor([[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]])
    

더 알아보기