def forward(self, input, source_hids):
    """Attend over the source hidden states using a projected query.

    Args:
        input: query tensor, ``bsz x input_embed_dim``.
        source_hids: source hidden states,
            ``srclen x bsz x output_embed_dim``.

    Returns:
        A tuple ``(x, attn_scores)``:

        * ``x`` -- ``tanh(self.output_proj([context; input]))``, where
          ``context`` is the attention-weighted sum of ``source_hids``.
          Its first dimension is ``bsz``; the feature size is whatever
          ``self.output_proj`` produces.
        * ``attn_scores`` -- attention weights, ``srclen x bsz``,
          normalized so that each batch column sums to 1 over ``srclen``.
    """
    # Project the query into the source hidden space:
    # bsz x output_embed_dim.
    query = self.input_proj(input)

    # Dot-product score between the query and every source position:
    # srclen x bsz.
    attn_scores = (source_hids * query.unsqueeze(0)).sum(dim=2)

    # Normalize over source positions. Passing ``dim`` explicitly avoids the
    # deprecated implicit-dimension softmax; the result is the same as the
    # former transpose -> softmax -> transpose round-trip.
    attn_scores = F.softmax(attn_scores, dim=0)

    # Attention-weighted sum of the source states: bsz x output_embed_dim.
    context = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)

    # Combine the context with the original input and squash.
    # ``torch.tanh`` replaces the deprecated ``F.tanh``.
    x = torch.tanh(self.output_proj(torch.cat((context, input), dim=1)))
    return x, attn_scores
评论列表
文章目录