def forward(self, last_state, states):
    """Score encoder `states` against the query `last_state`.

    Args:
        last_state: query tensor, assumed (batch, state_dim) — it is matmul'd
            with ``self.projection`` and broadcast over time. TODO confirm.
        states: encoder outputs, (seq, batch, state_dim); a 2-D input is
            treated as a length-1 sequence.

    Returns:
        1-D tensor of length ``batch`` with the (partially) normalized
        attention energies.
    """
    # Promote a single (batch, state_dim) step to a length-1 sequence.
    if len(states.size()) == 2:
        states = states.unsqueeze(0)
    sequence_length, batch_size, state_dim = states.size()

    # Project the query into the encoder space: (batch, encoder_dim).
    transformed_last_state = last_state @ self.projection
    # Broadcast over time, then flatten per batch row: (batch, seq*encoder_dim).
    transformed_last_state = transformed_last_state.expand(
        sequence_length, batch_size, self.encoder_dim)
    transformed_last_state = transformed_last_state.transpose(0, 1).contiguous()
    transformed_last_state = transformed_last_state.view(batch_size, -1)

    # Flatten the states identically: (batch, seq*state_dim).
    states = states.transpose(0, 1).contiguous()
    states = states.view(batch_size, -1)

    # One scalar energy per batch element: a dot product over the whole
    # flattened sequence.
    # NOTE(review): this collapses the time axis, so the softmax below
    # normalizes over the *batch* dimension, not over time — confirm that
    # this is the intended design before relying on it as attention.
    energies = (transformed_last_state * states).sum(dim=1)

    if self.encoder_dim is not None:
        # BUGFIX: slice (energies[:1]) instead of indexing (energies[0]) —
        # torch.cat rejects zero-dimensional tensors. The first element is
        # exp-scored (unnormalized); the rest are softmax-normalized.
        # dim=0 pins the previously implicit (deprecated) softmax axis,
        # matching the legacy default for 1-D input.
        attention_weights = torch.cat(
            [torch.exp(energies[:1]), F.softmax(energies[1:], dim=0)], dim=0)
    else:
        # NOTE(review): likely unreachable — self.encoder_dim was already
        # dereferenced in expand() above, which would fail if it were None.
        attention_weights = F.softmax(energies, dim=0)
    return attention_weights
# NOTE(review): removed two lines of web-scrape residue ("评论列表" / "文章目录",
# i.e. a blog's "comment list" / "table of contents" footer) that had been
# pasted in accidentally — as bare identifiers they would raise NameError.