def attention(self, hidden, W1xe, hidden_encoder):
    # Pointer-network attention scores: u_j = v^T tanh(W1 e_j + W2 d_n),
    # where e_j are the encoder states and d_n is the current decoder state.

    # --- train: compute all n + 1 scores with one batched matrix product ---
    W2xdn = torch.mm(hidden, self.W2)                      # (batch, hidden)
    W2xdn = W2xdn.unsqueeze(1).expand(self.batch_size,
                                      self.n + 1,
                                      self.hidden_size)    # broadcast over the n + 1 positions
    u = torch.bmm(torch.tanh(W1xe + W2xdn),
                  self.v.unsqueeze(0).expand(self.batch_size,
                                             self.hidden_size, 1))
    u = u.squeeze()                                        # (batch, n + 1) raw scores

    # --- test: equivalent per-position loop, kept for reference ---
    # W2xdn = torch.mm(hidden, self.W2)
    # u = Variable(torch.zeros(self.batch_size, self.n + 1)).type(dtype)
    # for n in xrange(self.n + 1):
    #     aux = torch.tanh(W1xe[:, n].squeeze() + W2xdn)   # (batch, hidden)
    #     aux2 = torch.bmm(aux.unsqueeze(1),
    #                      self.v.unsqueeze(0).expand(self.batch_size,
    #                                                 self.hidden_size, 1))
    #     u[:, n] = aux2.squeeze()
    return u
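
For context, the raw scores returned by attention() are typically normalized with a softmax to give a pointer distribution over the n + 1 input positions. A minimal sketch of that step, using made-up shapes and names (not part of the original code):

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only
batch_size, n, hidden_size = 4, 9, 32
u = torch.randn(batch_size, n + 1)   # stands in for the scores returned by attention()

attn = F.softmax(u, dim=1)           # pointer distribution over the n + 1 inputs
pointed = attn.argmax(dim=1)         # greedy index of the selected input element
print(attn.shape, pointed.shape)     # torch.Size([4, 10]) torch.Size([4])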