def forward(self, input, context):
    """
    input:   batch x dim           (decoder hidden state)
    context: batch x sourceL x dim (encoder outputs)
    """
    # Project the decoder state and add a trailing dim for bmm
    targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

    # Attention scores: dot product of each source position with the target
    attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
    if self.mask is not None:
        # Mask out padded source positions before the softmax
        # (in-place masked_fill_ on the tensor; going through .data is
        # old Variable-era style and bypasses autograd)
        attn.masked_fill_(self.mask, -float('inf'))
    attn = self.sm(attn)

    # Weighted sum of the context vectors using the attention weights
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL
    weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim

    # Concatenate with the original input and project (Luong-style "concat")
    contextCombined = torch.cat((weightedContext, input), 1)
    contextOutput = self.tanh(self.linear_out(contextCombined))
    return contextOutput, attn
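
For context, here is a minimal sketch of the module this `forward` could live in. The layer sizes are assumptions inferred from the shapes in the forward pass (a `dim -> dim` query projection and a `2*dim -> dim` output projection), and the class name `GlobalAttention` is hypothetical; the actual module may differ.

```python
import torch
import torch.nn as nn

class GlobalAttention(nn.Module):
    """Luong-style dot-product global attention (sketch, sizes assumed)."""
    def __init__(self, dim):
        super().__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)        # query projection
        self.sm = nn.Softmax(dim=-1)                            # normalize over sourceL
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)   # mix context + query
        self.tanh = nn.Tanh()
        self.mask = None  # optional batch x sourceL bool mask for padding

    def forward(self, input, context):
        targetT = self.linear_in(input).unsqueeze(2)            # batch x dim x 1
        attn = torch.bmm(context, targetT).squeeze(2)           # batch x sourceL
        if self.mask is not None:
            attn.masked_fill_(self.mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))        # batch x 1 x sourceL
        weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        contextCombined = torch.cat((weightedContext, input), 1)
        return self.tanh(self.linear_out(contextCombined)), attn

# Usage with dummy tensors (shapes follow the docstring above)
layer = GlobalAttention(dim=512)
h_t = torch.randn(4, 512)           # batch x dim
enc_out = torch.randn(4, 10, 512)   # batch x sourceL x dim
output, weights = layer(h_t, enc_out)  # (4, 512), (4, 10)
```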