def forward(self, last_state, states, mask=None):
    # states: (seq_len, batch, hidden); last_state: (batch, hidden)
    sequence_length, batch_size, hidden_dim = states.size()
    # Broadcast the decoder's last state across every encoder time step.
    last_state = last_state.unsqueeze(0).expand(sequence_length, batch_size, hidden_dim)
    if self.mode == "dot":
        # score(h_t, h_s) = h_t . h_s
        energies = last_state * states
        energies = energies.sum(dim=2)                           # (seq_len, batch)
    elif self.mode == "general":
        # score(h_t, h_s) = h_t . (W_a h_s)
        expanded_projection = self.projection.expand(sequence_length, *self.projection.size())
        energies = last_state * states.bmm(expanded_projection)
        energies = energies.sum(dim=2)                           # (seq_len, batch)
    elif self.mode == "concat":
        # score(h_t, h_s) = v_a . tanh(W_a [h_t; h_s])
        expanded_reduction = self.reduction.expand(sequence_length, *self.reduction.size())
        expanded_projection = self.projection.expand(sequence_length, *self.projection.size())
        energies = torch.tanh(torch.cat([last_state, states], dim=2).bmm(expanded_reduction))
        energies = energies.bmm(expanded_projection).squeeze(2)  # (seq_len, batch)
    if mask is not None:
        # Push padded positions toward zero attention weight.
        energies = energies + (mask == 0).float() * -10000
    # Normalize over the sequence dimension so each column sums to 1.
    attention_weights = F.softmax(energies, dim=0)
    return attention_weights
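
The forward() above depends on self.projection and self.reduction, whose shapes are never shown. Below is a minimal shape-check sketch of the "general" and "concat" scoring branches using plain tensors, assuming projection is (hidden_dim, hidden_dim) for "general", and that "concat" uses a reduction of shape (2*hidden_dim, hidden_dim) followed by a projection of shape (hidden_dim, 1). These shapes and the variable names below are inferred from the bmm calls, not stated in the source.

import torch

seq_len, batch, hidden = 5, 2, 8
states = torch.randn(seq_len, batch, hidden)
last_state = torch.randn(batch, hidden).unsqueeze(0).expand(seq_len, batch, hidden)

# "general": score = last_state . (W_a @ states), assuming W_a is (hidden, hidden)
projection = torch.randn(hidden, hidden)
energies_general = (last_state * states.bmm(projection.expand(seq_len, hidden, hidden))).sum(dim=2)
print(energies_general.shape)  # torch.Size([5, 2])

# "concat": score = v_a^T tanh(W_a [last_state; states]),
# assuming W_a is (2*hidden, hidden) and v_a is (hidden, 1)
reduction = torch.randn(2 * hidden, hidden)
v = torch.randn(hidden, 1)
hidden_cat = torch.tanh(torch.cat([last_state, states], dim=2).bmm(reduction.expand(seq_len, 2 * hidden, hidden)))
energies_concat = hidden_cat.bmm(v.expand(seq_len, hidden, 1)).squeeze(2)
print(energies_concat.shape)  # torch.Size([5, 2])

Under these assumed shapes, both branches produce an energies tensor of shape (seq_len, batch), so the softmax in forward() normalizes over the sequence dimension to yield one attention weight per encoder position.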