def _attend(self, p):
    # Project the previous decoder state into the attention space.
    p = self.xh(p)
    # Insert a source-length axis and broadcast so the projection can be
    # added to every encoder position: (batch, dim) -> (batch, src_len, dim).
    p = F.expand_dims(p, 1)
    p = F.broadcast_to(p, self.shape2)
    # Additive (Bahdanau-style) attention energies: tanh(h_j + W p).
    h = F.tanh(self.h + p)
    # Flatten the batch and source-length axes so the scoring layer can be
    # applied to all positions in a single call.
    shape3 = (self.batchsize * self.src_len, self.dim_hid)
    h_reshaped = F.reshape(h, shape3)
    weight_reshaped = self.hw(h_reshaped)
    weight = F.reshape(weight_reshaped, (self.batchsize, self.src_len, 1))
    # Replace scores at padded source positions with -inf so they receive
    # zero probability after the softmax.
    weight = F.where(self.mask, weight, self.minf)
    # Normalize over the source-length axis (Chainer's default axis=1).
    attention = F.softmax(weight)
    return attention
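
For context, here is a minimal sketch of the surrounding class, assuming Chainer. The link names `xh` and `hw` and the attributes `h`, `mask`, `minf`, and `shape2` mirror the snippet above; the constructor signature, the `precompute` step, and the exact shapes are assumptions made for illustration, not the author's actual class.

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L

class Attention(chainer.Chain):
    def __init__(self, dim_hid):
        super().__init__()
        self.dim_hid = dim_hid
        with self.init_scope():
            self.xh = L.Linear(dim_hid, dim_hid)  # projects the decoder state
            self.hw = L.Linear(dim_hid, 1)        # scores each source position

    def precompute(self, enc_states, mask):
        # enc_states: (batch, src_len, dim_hid) encoder hidden states
        # (possibly already projected); mask: (batch, src_len, 1) bool,
        # True at real tokens, False at padding. All assumed shapes.
        self.batchsize, self.src_len, _ = enc_states.shape
        self.shape2 = (self.batchsize, self.src_len, self.dim_hid)
        self.h = enc_states
        self.mask = mask
        self.minf = self.xp.full(mask.shape, -np.inf, dtype=np.float32)

    def _attend(self, p):
        ...  # method shown above

Hypothetical usage: weight the encoder states by the returned attention to form a context vector for the decoder.

att = model._attend(prev_state)                        # (batch, src_len, 1)
context = F.sum(model.h * F.broadcast_to(att, model.shape2), axis=1)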