I implemented a Sequence to Sequence model with the attention mechanism in PyTorch! I'd like to ask you to check whether there is anything in the code that is theoretically incorrect!
import torch
import torch.nn as nn
import torch.nn.functional as F

def verbose_shape(*tensors):
    # Debugging helper: print the shape of each tensor on a single line.
    for t in tensors:
        print(t.shape, end=' ')
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(Encoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.embed_layer = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_size)
        self.lstm = nn.LSTM(input_size=self.embedding_size, hidden_size=self.hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)

    def forward(self, x):
        x = self.embed_layer(x)        # (batch, T) -> (batch, T, embedding_size)
        hs, (h, c) = self.lstm(x)      # hs: (batch, T, hidden), h/c: (1, batch, hidden)
        z = self.fc(hs)                # (batch, T, vocab_size)
        a = F.softmax(z, dim=-1)       # per-step probabilities over the vocabulary
        y = torch.argmax(a, dim=-1)    # (batch, T) predicted token ids
        return hs, h, c, a, y
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.embed_layer = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_size)
        self.lstm = nn.LSTM(input_size=self.embedding_size, hidden_size=self.hidden_size, num_layers=1, batch_first=True)
        # The context vector is concatenated with the decoder hidden state, hence hidden_size*2.
        self.affine = nn.Linear(in_features=self.hidden_size*2, out_features=self.vocab_size)

    def forward(self, enc_hs, enc_h, enc_c, enc_pred):
        """
        enc_pred: the encoder's sequence prediction (the argmax of its output probabilities)
        """
        x = self.embed_layer(enc_pred)              # (batch, T, embedding_size)
        dec_hs, _ = self.lstm(x, (enc_h, enc_c))    # (batch, T, hidden), initialized with the encoder state
        #==========
        # Attention
        #==========
        # 1. Weight computation
        s = torch.bmm(dec_hs, enc_hs.transpose(1, 2))  # (28, 32, 32) = (batch, dec_T, enc_T)
        # Each decoder step's weights over the encoder steps should sum to 1, so the
        # softmax is taken along the enc_T axis, i.e. the last dimension
        # (see the dummy-tensor shape walk-through right after this class).
        a = F.softmax(s, dim=-1)
        # 2. Selection (weighted sum of the encoder hidden states)
        c = torch.bmm(a, enc_hs)                       # (batch, dec_T, hidden)
        c_hs = torch.cat((c, dec_hs), dim=-1)          # (batch, dec_T, hidden*2)
        z = self.affine(c_hs)
        y = F.softmax(z, dim=-1)                       # (batch, dec_T, vocab_size)
        return y
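
# --- Quick sanity check of the attention shapes. This is my own throwaway sketch,
# --- not part of the model; the underscore-prefixed names are dummy tensors with
# --- the same sizes used below (batch=28, enc_T=dec_T=32, hidden=256).
# The score matrix is (batch, dec_T, enc_T), and a softmax over the last axis makes
# each decoder step's weights over the encoder steps sum to 1.
_dec_hs = torch.randn(28, 32, 256)
_enc_hs = torch.randn(28, 32, 256)
_s = torch.bmm(_dec_hs, _enc_hs.transpose(1, 2))    # (28, 32, 32) = (batch, dec_T, enc_T)
_a = torch.softmax(_s, dim=-1)                      # normalize along enc_T
_c = torch.bmm(_a, _enc_hs)                         # (28, 32, 256): weighted sum of encoder states
print(_a.sum(dim=-1).allclose(torch.ones(28, 32)))  # True -> weights sum to 1 per decoder step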
# params
durations_vocab_size = 5
embedding_size = 128
hidden_size = 256
# dataset
x_durations = torch.randint(low=0, high=4, size=(28, 32))
# model
encoder = Encoder(durations_vocab_size, embedding_size, hidden_size)
decoder = Decoder(durations_vocab_size, embedding_size, hidden_size)
# forward
hs, h, c, y_prob, y = encoder(x_durations)
y = decoder(hs, h, c, y)
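For reference, this is how I inspected the output shapes with the verbose_shape helper defined above; the comment shows the shapes I expect given batch=28, T=32, hidden=256, vocab=5:
verbose_shape(hs, h, c, y_prob, y)
# Expected: torch.Size([28, 32, 256]) torch.Size([1, 28, 256]) torch.Size([1, 28, 256])
#           torch.Size([28, 32, 5]) torch.Size([28, 32, 5])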