Model.py source code

python

Project: MemNN  Author: berlino
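import torch
import torch.nn.functional as F


def forward(self, qu, key, value, cand):
    # Tensors track gradients directly since PyTorch 0.4, so the
    # deprecated Variable wrappers are no longer needed.
    embed_q = self.embed_B(qu)       # question tokens
    embed_w1 = self.embed_A(key)     # memory key tokens
    embed_w2 = self.embed_C(value)   # memory values
    embed_c = self.embed_C(cand)     # candidate answers

    # Bag-of-words encoding: sum token embeddings along dim 1.
    # Current torch.sum drops the reduced dimension, so the old
    # .squeeze(1) is unnecessary.
    q_state = torch.sum(embed_q, dim=1)
    w1_state = torch.sum(embed_w1, dim=1)
    w2_state = embed_w2

    for _ in range(self.config.hop):
        # Attention of each question over the memory keys.
        sent_dot = torch.mm(q_state, torch.transpose(w1_state, 0, 1))
        sent_att = F.softmax(sent_dot, dim=1)

        # Attention-weighted sum of memory values, transformed by H
        # and added to the query state as a residual update.
        a_dot = torch.mm(sent_att, w2_state)
        a_dot = self.H(a_dot)
        q_state = torch.add(a_dot, q_state)

    # Score every candidate against the final query state.
    f_feat = torch.mm(q_state, torch.transpose(embed_c, 0, 1))
    score = F.log_softmax(f_feat, dim=1)
    return score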
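
The listing shows only the forward method, so the surrounding class is not visible. The following is a minimal sketch of how it might be wrapped and called, to make the tensor shapes concrete. The class name `KVMemNNSketch`, the constructor arguments, the choice of `nn.Embedding` for `embed_A`/`embed_B`/`embed_C`, and `nn.Linear` for `H` are all assumptions, not taken from the MemNN project; only the attribute names and `config.hop` come from the method above.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class KVMemNNSketch(nn.Module):
    """Hypothetical wrapper: attribute names follow forward(); layer types and sizes are assumed."""

    def __init__(self, vocab_size, embed_dim, hop):
        super().__init__()
        self.config = SimpleNamespace(hop=hop)               # stand-in for the project's config object
        self.embed_A = nn.Embedding(vocab_size, embed_dim)   # memory keys
        self.embed_B = nn.Embedding(vocab_size, embed_dim)   # question
        self.embed_C = nn.Embedding(vocab_size, embed_dim)   # values and candidates
        self.H = nn.Linear(embed_dim, embed_dim)             # assumed to be a square linear map

    def forward(self, qu, key, value, cand):
        q_state = self.embed_B(qu).sum(dim=1)       # (B, D)
        w1_state = self.embed_A(key).sum(dim=1)     # (M, D)
        w2_state = self.embed_C(value)              # (M, D)
        embed_c = self.embed_C(cand)                # (C, D)
        for _ in range(self.config.hop):
            sent_att = F.softmax(q_state @ w1_state.t(), dim=1)   # (B, M)
            q_state = self.H(sent_att @ w2_state) + q_state       # (B, D)
        return F.log_softmax(q_state @ embed_c.t(), dim=1)        # (B, C)


torch.manual_seed(0)
model = KVMemNNSketch(vocab_size=50, embed_dim=8, hop=2)
qu = torch.randint(0, 50, (3, 5))     # 3 questions, 5 token ids each
key = torch.randint(0, 50, (4, 6))    # 4 memory slots, 6 key token ids each
value = torch.randint(0, 50, (4,))    # one value id per memory slot
cand = torch.randint(0, 50, (7,))     # 7 candidate answer ids
score = model(qu, key, value, cand)
print(score.shape)  # torch.Size([3, 7])
```

Each row of `score` holds log-probabilities over the candidates, so it can be passed directly to `F.nll_loss` with the index of the correct candidate as the target.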