Model.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:MemNN 作者: berlino 项目源码 文件源码
def forward(self, qu, w, cand):
        """Sample one memory entry by attention and score the candidates.

        The three inputs are wrapped in ``Variable`` and embedded with
        ``embed_B`` (question), ``embed_A`` (memory) and ``embed_C``
        (candidates).  Question and memory embeddings are summed along
        dim 1, their pairwise dot products are soft-maxed into an attention
        distribution, and one index is drawn from it with ``multinomial``.
        The memory embedding at that index is run through ``rnn_doc``; the
        resulting state is added to the summed question vector, which is
        then scored against the candidate embeddings with a log-softmax.

        Args:
            qu: question indices fed to ``embed_B``.
            w: memory indices fed to ``embed_A``.
            cand: candidate indices fed to ``embed_C``.
            (assumes all three are index tensors accepted by the embedding
            layers; shapes are not visible here -- TODO confirm with caller)

        Returns:
            A tuple ``(action, reward_prob)``: the sampled attention index
            and the log-softmax scores over the candidates with dim 0
            squeezed.
        """
        query_emb = self.embed_B(Variable(qu))
        memory_emb = self.embed_A(Variable(w))
        cand_emb = self.embed_C(Variable(cand))

        # Collapse dim 1 of each embedding by summation.
        query_vec = query_emb.sum(1).squeeze(1)
        memory_vec = memory_emb.sum(1).squeeze(1)

        # Attention over memory entries: softmax of question/memory dot products.
        scores = query_vec.mm(memory_vec.transpose(0, 1))
        attention = F.softmax(scores)

        # NOTE(review): this question-RNN state is computed but never used
        # below; the call is kept so the sequence of operations is unchanged.
        _unused_q_rnn_state = self.rnn_qus(query_emb, self.h0_q)[-1].squeeze(0)

        # Draw the memory entry to attend to from the attention distribution.
        action = attention.multinomial()

        # Encode the selected memory entry and fold it into the question vector.
        picked = memory_emb[action.data[0]]
        picked_state = self.rnn_doc(picked, self.h0_doc)[-1].squeeze(0)
        combined = query_vec + picked_state

        # Score every candidate against the combined representation.
        logits = combined.mm(cand_emb.transpose(0, 1))
        reward_prob = F.log_softmax(logits).squeeze(0)

        return action, reward_prob
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号