Merge.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:DCN 作者: alexnowakvila 项目源码 文件源码
def attention(self, hidden, W1xe, hidden_encoder):
        # train
        W2xdn = torch.mm(hidden, self.W2)
        W2xdn = W2xdn.unsqueeze(1).expand(self.batch_size, self.n + 1,
                                          self.hidden_size)
        u = (torch.bmm(torch.tanh(W1xe + W2xdn), self.v.unsqueeze(0)
             .expand(self.batch_size, self.hidden_size, 1)))
        u = u.squeeze()
        # test
        # W2xdn = torch.mm(hidden, self.W2)
        # u = Variable(torch.zeros(self.batch_size, self.n + 1)).type(dtype)
        # for n in xrange(self.n + 1):
        #     aux = torch.tanh(W1xe[:, n].squeeze() + W2xdn)  # size bs x hidd
        #     aux2 = (torch.bmm(aux.unsqueeze(1), self.v.unsqueeze(0)
        #             .expand(self.batch_size, self.hidden_size, 1)))
        #     u[:, n] = aux2.squeeze()
        return u
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号