def forward(self, dec_out, enc_outs, enc_att=None, mask=None):
    """Compute an attention-weighted context vector over encoder outputs.

    Parameters:
    -----------
    - dec_out: torch.Tensor(batch_size x hid_dim), current decoder state
    - enc_outs: torch.Tensor(seq_len x batch_size x hid_dim)
    - enc_att: (optional), torch.Tensor(seq_len x batch_size x att_dim),
      precomputed projection of enc_outs passed through to the scorer
    - mask: (optional), torch.BoolTensor/ByteTensor(batch_size x seq_len),
      nonzero for valid source positions, 0 for padding

    Returns:
    --------
    - context: torch.Tensor(batch_size x hid_dim)
    - weights: torch.Tensor(batch_size x seq_len), attention distribution
    """
    # attention scores per source position: (batch x seq_len)
    weights = self.scorer(dec_out, enc_outs, enc_att=enc_att)
    if mask is not None:
        # Zero-out padded positions by sending their scores to -inf before
        # softmax. Out-of-place masked_fill keeps the op on the autograd
        # graph (the original mutated .data in place, bypassing autograd),
        # and `mask == 0` works for both bool and uint8 masks, unlike the
        # deprecated `1 - mask`.
        weights = weights.masked_fill(mask == 0, -float('inf'))
    weights = F.softmax(weights, dim=1)
    # (eq 7) weighted sum of encoder outputs: (batch x hid_dim)
    context = weights.unsqueeze(1).bmm(enc_outs.transpose(0, 1)).squeeze(1)
    # (eq 5) linear out combining context and hidden;
    # torch.tanh replaces the deprecated F.tanh
    context = torch.tanh(self.linear_out(torch.cat([context, dec_out], 1)))
    return context, weights