def forward(self, input, context, mask=None):
    """Propagate input through the network using global attention over context.

    Args:
        input: batch x dim query vectors.
        context: batch x sourceL x dim source states to attend over.
        mask: optional bool tensor of shape batch x sourceL (or broadcastable
            to it). Positions that are True have their score set to -inf
            before ``self.sm``, so (assuming ``self.sm`` is a softmax over
            sourceL) they receive zero attention weight. If every position
            in a row is masked, that row's weights become NaN.

    Returns:
        A tuple ``(h_tilde, attn)``:
        - h_tilde: batch x dim, ``self.tanh(self.linear_out([weighted_context; input]))``.
        - attn: batch x sourceL, the normalized attention weights from ``self.sm``.
    """
    target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

    # Attention scores: dot product of every source state with the projected query.
    attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL
    if mask is not None:
        # Out-of-place fill keeps autograd happy; -inf scores normalize to 0.
        attn = attn.masked_fill(mask, float("-inf"))
    attn = self.sm(attn)

    attn3 = attn.unsqueeze(1)  # batch x 1 x sourceL
    weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim

    h_tilde = torch.cat((weighted_context, input), 1)  # batch x 2*dim
    h_tilde = self.tanh(self.linear_out(h_tilde))
    return h_tilde, attn
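# 评论列表 ("Comment list") / 文章目录 ("Table of contents"): leftover
# blog-page navigation labels captured with the code; not part of the program.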