Embedding 함수

안녕하세요! 파이토치 입문자 입니다. nn.Embedding 함수 관련해서 질문을 드리고 싶은데요? 아래와 같은 예시코드가 있을 때, nn.Embedding 함수의 첫번째 인자(num_embeddings 값)에 10보다 작은 값이 들어가면 왜 에러가 발생하는지 이해가 잘 가지 않습니다.

import torch.nn as nn
import torch

input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # (2, 4)
embed = nn.Embedding(10, 3)  # (10,3) shape의 임베딩 파라미터
embed(input).shape

nn.Embedding(10, 3) 으로 하면 (10, 3) shape의 임베딩 파라미터가 생긴다는 것은 알겠는데…왜 아래처럼 nn.Embedding 인자의 첫번째 파라미터 값을 10보다 작은 값으로 하면 에러가 발생하는 건가요? 구글링해도 잘 모르겠네요…

import torch.nn as nn
import torch

input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # (2, 4)
embed = nn.Embedding(9, 3) # index error
embed(input).shape

여기 예제 참조하시면 아실수 있습니다 첫번째 인자는 임베딩의 크기입니다

word_to_ix = {"hello": 0, "world": 1}
embeds = nn.Embedding(2, 5)  # 2 words in vocab, 5 dimensional embeddings
lookup_tensor = torch.tensor([word_to_ix["hello"]], dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)

해당 데이터의 크기가 얼마인지 보는것이므로 그거보다 크게 나와야합니다
torch.max(input) 해서 나온 값보다 +1 해서 해주는게 안전합니다

즉 embedding(9,3)으로 해줄거면 input에 있는 최고값은 8로 해주면 됩니다

1개의 좋아요