class TimeAwareFeatureAttention(nn.Module):
    """Learn per-feature and per-time-step importance weights.

    Expects input of shape (batch, seq_length, input_size) and returns a
    tensor of the same shape, element-wise rescaled by both learned maps.
    """

    def __init__(self, input_size, seq_length):
        # FIX: was `init` in the paste — Python requires `__init__`; otherwise
        # nn.Module's default constructor runs and W_f / W_t are never created.
        super().__init__()
        self.W_f = nn.Linear(input_size, input_size)  # learns feature importance
        self.W_t = nn.Linear(seq_length, seq_length)  # learns time-step importance

    def forward(self, x):
        # x: (batch, seq_length, input_size)
        # Softmax over the feature axis -> per-feature weights.
        feature_weights = torch.softmax(self.W_f(x), dim=-1)
        # Transpose so the linear layer mixes along the time axis, softmax over
        # time, then restore the original (batch, seq, feature) layout.
        time_weights = torch.softmax(
            self.W_t(x.permute(0, 2, 1)), dim=-1
        ).permute(0, 2, 1)
        # Apply both importance maps to the input.
        return x * feature_weights * time_weights
class TemporalAttention(nn.Module):
    """Single-query scaled dot-product attention over a sequence of states.

    Given a query of shape (batch, hidden_size) and values of shape
    (batch, seq_length, hidden_size), returns a context vector
    (batch, hidden_size) and the attention weights (batch, 1, seq_length).
    """

    def __init__(self, hidden_size):
        # FIX: was `init` in the paste — must be `__init__` for the projection
        # layers below to be registered on the module.
        super().__init__()
        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, hidden_size)
        # FIX: a plain Python float replaces torch.sqrt(torch.tensor(...,
        # device=DEVICE)) — it avoids depending on an undefined global DEVICE
        # and works regardless of which device the module is later moved to.
        self.scale = float(hidden_size) ** 0.5

    def forward(self, query, values):
        Q = self.W_q(query).unsqueeze(1)  # (batch, 1, hidden_size)
        K = self.W_k(values)              # (batch, seq_length, hidden_size)
        V = self.W_v(values)              # (batch, seq_length, hidden_size)
        # Scaled dot-product: (batch, 1, seq_length)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention_weights = torch.softmax(attention_scores, dim=-1)
        # Weighted sum of values, squeezed back to (batch, hidden_size).
        context_vector = torch.matmul(attention_weights, V).squeeze(1)
        return context_vector, attention_weights
class LSTMFeatureTemporalAttentionModel(nn.Module):
    """Feature attention -> LSTM -> temporal attention -> linear head.

    Forward takes x of shape (batch, seq_length, input_size) and returns
    (output, attention_weights) where output is (batch, 1).
    """

    def __init__(self, input_size, hidden_size, num_layers):
        # FIX: was `init` in the paste — must be `__init__`, otherwise none of
        # the submodules below are ever constructed.
        super().__init__()
        # NOTE(review): `FeatureAttention` is not defined in this file; the
        # class defined above is `TimeAwareFeatureAttention(input_size,
        # seq_length)`. As written this raises NameError unless
        # FeatureAttention exists elsewhere — confirm which module is intended.
        self.feature_attention = FeatureAttention(input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.temporal_attention = TemporalAttention(hidden_size)
        self.fc = nn.Linear(hidden_size, 1)  # scalar prediction per sequence

    def forward(self, x):
        # Reweight input features before the recurrent encoder.
        x = self.feature_attention(x)
        lstm_out, (hn, cn) = self.lstm(x)
        # Use the top layer's final hidden state as the attention query.
        query = hn[-1]
        context_vector, attention_weights = self.temporal_attention(query, lstm_out)
        output = self.fc(context_vector)  # (batch, 1)
        return output, attention_weights
제가 구성한 모델인데, 이를 수식화 및 도식화하면 어떻게 표현해야 할까요?
그리고 논리적인 오류는 없을까요? (예: 클래스 이름이 `TimeAwareFeatureAttention`인데 모델에서는 `FeatureAttention`을 호출하는 부분, `__init__` 메서드 이름 등)