def score(self, hidden, encoder_output):
    """Return the attention energy between ``hidden`` and ``encoder_output``.

    The scoring rule is selected by ``self.method``:

    * ``'dot'``     -- dot product of the two flattened tensors.
    * ``'general'`` -- ``encoder_output`` is first passed through
      ``self.attn`` and the result is dotted with the flattened ``hidden``.

    Both operands are flattened with ``view(-1)`` before ``torch.dot``, so
    they must hold the same number of elements (for ``'general'`` this
    applies to the output of ``self.attn``, not to ``encoder_output``).
    NOTE(review): presumably ``self.attn`` is a linear layer and the inputs
    are single-timestep vectors -- confirm against the enclosing class.

    Args:
        hidden: Tensor holding the current (decoder) hidden state.
        encoder_output: Tensor holding one encoder output.

    Returns:
        A 0-dim tensor with the energy, or ``None`` when ``self.method`` is
        neither ``'dot'`` nor ``'general'``.
    """
    if self.method == 'dot':
        return torch.dot(hidden.view(-1), encoder_output.view(-1))

    if self.method == 'general':
        # Use the projected encoder output in the dot product. Previously the
        # projection was computed and then thrown away, which made 'general'
        # behave exactly like 'dot'.
        projected = self.attn(encoder_output)
        return torch.dot(hidden.view(-1), projected.view(-1))

    # Any other method is not handled here; keep the historical implicit
    # ``None`` result rather than raising.
    return None
评论列表
文章目录