현재 pytorch를 활용해서 Attention 기술이 적용된 LSTM 기반 Sequence to Sequence 모델을 만들려고 합니다. Encoder 부분은 잘 완성했는데, Decoder 부분에서 Attention을 구현하는 데서 이해가 되지 않는 부분이 있습니다.
Attention은 크게 2가지 부분으로 구성되어 있고, 1번째 단계는 Encoder의 hidden_state 벡터와 Decoder의 1개 hidden_state 벡터 간의 Weight(가중치)벡터를 계산하는 단계, 2번째 단계는 Encoder hidden_state 벡터에 (1번째 단계에서 구한) Weight 벡터를 곱해주어 맥락 벡터를 만드는 단계 총 2개로 알고 있습니다.
제가 문제에 직면한 단계는 2번째 단계인데요. Weight 벡터를 구하긴 했는데, 2번째 단계에서 Decoder에 들어오는 모든 시퀀스에 대해 맥락 벡터를 어떻게 계산해야 할지 모르겠습니다. Decoder에 들어오는 단일 시퀀스에 대해 맥락 벡터를 계산하는 것은 어떻게든 하겠는데, 이렇게 하면 for loop 구문을 사용해야 하는 등 텐서의 벡터화 연산을 사용하지 못하게 되므로 비효율적인 것 같다는 생각입니다. 텐서의 브로드캐스팅 기능을 사용해서 계산할 수 있을 것 같은데, 어떻게 해야 할지 모르겠습니다.
우선 제가 만들려는 모델 및 데이터셋 코드는 아래와 같습니다.
import torch
import torch.nn as nn
def verbose_shape(*tensors):
for tensor in tensors:
print(tensor.shape, end=' ')
class Encoder(nn.Module):
def __init__(self,
notes_vocab_size,
durations_vocab_size,
embedding_size,
hidden_size):
super(Encoder, self).__init__()
self.notes_vocab_size = notes_vocab_size
self.durations_vocab_size = durations_vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.notes_embed = nn.Embedding(num_embeddings=self.notes_vocab_size, embedding_dim=embedding_size)
self.durations_embed = nn.Embedding(num_embeddings=self.durations_vocab_size, embedding_dim=embedding_size)
self.lstm = nn.LSTM(input_size=self.embedding_size*2, hidden_size=self.hidden_size, num_layers=1, batch_first=True)
self.notes_fc = nn.Linear(in_features=self.hidden_size, out_features=self.notes_vocab_size)
self.durations_fc = nn.Linear(in_features=self.hidden_size, out_features=self.durations_vocab_size)
def forward(self, x_notes, x_durations):
x_notes_embed = self.notes_embed(x_notes)
x_durations_embed = self.durations_embed(x_durations)
x = torch.cat((x_notes_embed, x_durations_embed), dim=-1)
hs, (h, c) = self.lstm(x)
a_notes = F.softmax(self.notes_fc(hs), dim=-1) # (batch-size, sequence_length, vocab_size)
a_durations = F.softmax(self.durations_fc(hs), dim=-1) # one-hot encoding
notes_pred = torch.argmax(a_notes, dim=-1)
durations_pred = torch.argmax(a_durations, dim=-1)
return hs, h, c, notes_pred, durations_pred
class Decoder(nn.Module):
def __init__(self,
notes_vocab_size,
durations_vocab_size,
embedding_size,
hidden_size):
super(Decoder, self).__init__()
self.notes_vocab_size = notes_vocab_size
self.durations_vocab_size = durations_vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.notes_embed = nn.Embedding(num_embeddings=self.notes_vocab_size, embedding_dim=self.embedding_size)
self.durations_embed = nn.Embedding(num_embeddings=self.durations_vocab_size, embedding_dim=self.embedding_size)
self.lstm = nn.LSTM(input_size=self.embedding_size*2, hidden_size=self.hidden_size, num_layers=1, batch_first=True)
def forward(self, enc_hs, enc_h, enc_c, notes_pred, durations_pred):
x_notes_embed = self.notes_embed(notes_pred)
x_durations_embed = self.durations_embed(durations_pred)
x = torch.cat((x_notes_embed, x_durations_embed), dim=-1)
dec_hs, _ = self.lstm(x, (enc_h, enc_c))
alpha = torch.bmm(enc_hs, dec_hs.transpose(1, 2)) # 28무시하고 (32, 1) -> decoder의 1개 hidden_state가 encoder의 모든 hidden_state 과의 관계 weight
weight = F.softmax(alpha, dim=1) # (32, 1) -> column-vector 1개가 enc_hs에 곱해져야 함
verbose_shape(enc_hs, weight)
c = enc_hs * weight
return c
# params
notes_vocab_size = 32
durations_vocab_size = 5
embedding_size = 128
hidden_size = 256
# dataset
x_notes = torch.rand(28 ,32).to(torch.int32)
x_durations = torch.rand(28, 32).to(torch.int32)
# model
encoder = Encoder(notes_vocab_size, durations_vocab_size, embedding_size, hidden_size)
decoder = Decoder(notes_vocab_size, durations_vocab_size, embedding_size, hidden_size)
# forward
hs, h, c, notes_pred, durations_pred = encoder(x_notes, x_durations)
a = decoder(hs, h, c, notes_pred, durations_pred)
위 코드에서 enc_hs가 Encoder에서 나온 전체 hidden_state 벡터를 의미하고, weight가 위에서 이야기한 Attention 구현 단계 중 1단계를 수행하면서 계산되는 가중치를 의미합니다. enc_hs 벡터의 형상은 (28, 32, 256)이며 각 차원의 의미는 (batch_size, sequence_length, hidden_state_unit) 입니다. 그리고 weight 벡터의 형상은 (28, 32, 32) 이며 (batch_size, sequence_length, sequence_length) 입니다. (28, 32, 32) 형상 중 (1, 32, 1) 벡터는 batch_size가 1이면서 Decoder 인풋으로 들어온 하나의 hidden_state 벡터가 Encoder의 모든 hidden_state 벡터 간 계산된 각 weight를 의미합니다.
아마 설명이 이해가 잘 안될 수도 있는데요. 추가적인 답글 남겨주시면 보충 설명 드리도록 하겠습니다! 감사합니다.