공식 홈페이지와 StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다.
다음 링크에서 원문을 함께 찾아보실 수 있습니다.
질문
-
mnist
와 같은 분류 문제에서, 정답 레이블(label)을 정수(int) 값으로 가지고 있습니다. - 이 정수 값을 어떻게 원-핫 벡터(One-hot vector) 형식으로 바꿀 수 있나요?
공식 홈페이지와 StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다.
다음 링크에서 원문을 함께 찾아보실 수 있습니다.
mnist
와 같은 분류 문제에서, 정답 레이블(label)을 정수(int) 값으로 가지고 있습니다.torch.nn.functional.one_hot()
함수를 사용하여 원-핫 형식으로 바꿀 수 있습니다.>>> import torch
>>> torch.arange(0, 5) % 3
tensor([0, 1, 2, 0, 1])
torch.nn.functional.one_hot()
의 동작은 다음과 같습니다.>>> import torch.nn.functional as F
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0]])
num_classes
는 정답 클래스의 총 개수로, 없는 경우 주어진 텐서의 가장 큰 값보다 1만큼 큰 수로 설정됩니다.>>> import torch.nn.functional as F
>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 1, 0]])