def forward(self, x, target_embedding, encoder_out):
    residual = x

    # attention: project the decoder conv output to the embedding dimension,
    # add the target embeddings, and scale by sqrt(0.5) to keep the variance
    # of the sum roughly constant
    x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)

    # dot-product scores against the encoder keys: (bsz, tgt_len, src_len)
    x = self.bmm(x, encoder_out[0])

    # softmax over the last (source-length) dimension
    sz = x.size()
    x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
    x = x.view(sz)
    attn_scores = x

    # weighted sum of the encoder values: (bsz, tgt_len, embed_dim)
    x = self.bmm(x, encoder_out[1])

    # scale attention output: the weighted sum averages over s source positions,
    # so multiplying by s * sqrt(1/s) = sqrt(s) counteracts the variance
    # reduction (assuming roughly uniform attention weights)
    s = encoder_out[1].size(1)
    x = x * (s * math.sqrt(1.0 / s))

    # project back to the convolution channel size, add the residual,
    # and again scale the sum by sqrt(0.5)
    x = (self.out_projection(x) + residual) * math.sqrt(0.5)
    return x, attn_scores
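To make the tensor shapes concrete, here is a minimal standalone sketch of the same computation. The dimensions (batch size 2, target length 5, source length 7, 256 conv channels, embedding dim 128), the names `encoder_keys` / `encoder_values`, and the freestanding `nn.Linear` layers standing in for `self.in_projection` / `self.out_projection` are illustrative assumptions, not part of the original module:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

bsz, tgt_len, src_len = 2, 5, 7
conv_channels, embed_dim = 256, 128

in_projection = nn.Linear(conv_channels, embed_dim)
out_projection = nn.Linear(embed_dim, conv_channels)

x = torch.randn(bsz, tgt_len, conv_channels)            # decoder conv output
target_embedding = torch.randn(bsz, tgt_len, embed_dim)
encoder_keys = torch.randn(bsz, embed_dim, src_len)      # plays the role of encoder_out[0]
encoder_values = torch.randn(bsz, src_len, embed_dim)    # plays the role of encoder_out[1]

residual = x
q = (in_projection(x) + target_embedding) * math.sqrt(0.5)
scores = torch.bmm(q, encoder_keys)                      # (bsz, tgt_len, src_len)
attn = F.softmax(scores, dim=-1)                         # normalize over source positions
ctx = torch.bmm(attn, encoder_values)                    # (bsz, tgt_len, embed_dim)
ctx = ctx * (src_len * math.sqrt(1.0 / src_len))         # rescale the averaged values
out = (out_projection(ctx) + residual) * math.sqrt(0.5)  # back to conv channels + residual

print(out.shape, attn.shape)  # torch.Size([2, 5, 256]) torch.Size([2, 5, 7])
```

The returned pair mirrors the method above: the attended, re-projected features that feed the next decoder layer, and the attention weights themselves for inspection or visualization.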