Model.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:MemNN 作者: berlino 项目源码 文件源码
def forward(self, qu, w, e_p):
        """Score every embedding-table entry as the answer to a question over documents.

        Pipeline: embed and encode the question and the documents with two
        RNNs, then apply two levels of attention:

        * token attention inside each document between position 1 (the
          "title" token) and position ``k`` (the "entity" token);
        * sentence attention across documents.

        Finally, every row of ``self.embed.weight`` is scored against the
        pooled feature.

        Args:
            qu: Index tensor for the question. Only batch row 0 of the
                encoded question is used (presumably shape ``(1, q_len)`` --
                TODO confirm against callers).
            w: Index tensor for the documents, one row per document
                (assumes ``(n_docs, doc_len)`` with batch first).
            e_p: Iterable of per-document entity positions. Entry ``i`` is
                the position ``k`` read as ``out_doc[i, k]``.

        Returns:
            Tensor of shape ``(1, num_embeddings)`` holding log-probabilities
            (log-softmax over dim 1).
        """
        qu = Variable(qu)
        w = Variable(w)
        embed_q = self.embed(qu)
        embed_w = self.embed(w)
        b_size = embed_w.size()[0]
        fea = self.config.rnn_fea_size

        # One copy of the learned initial document state per document,
        # concatenated along dim 1.
        h0_doc = Variable(torch.cat([self.h0_doc for _ in range(b_size)], 1))
        out_qus, _ = self.rnn_qus(embed_q, self.h0_q)
        out_doc, _ = self.rnn_doc(embed_w, h0_doc)

        # Question summary: the first `fea` features of the last step joined
        # with the remaining features of the first step (the two directions'
        # final states if the RNN is bidirectional -- TODO confirm).
        q_state = torch.cat([out_qus[0, -1, :fea], out_qus[0, 0, fea:]], 0)

        # Token attention: per document, mix the title and entity states.
        doc_tit_ent = []
        doc_states = []
        for i, k in enumerate(e_p):
            # self.cat must return a 2-row matrix (title row, entity row)
            # because it is multiplied by the (1, 2) attention weights below.
            t_e_v = self.cat(out_doc[i, 1], out_doc[i, k])
            # torch.dot yields a 0-dim tensor on current PyTorch, which
            # torch.cat rejects. View each score as shape (1,) so this works
            # on old and new versions alike.
            title = torch.dot(out_doc[i, 1], q_state).view(1)
            entity = torch.dot(out_doc[i, k], q_state).view(1)
            token_att = torch.cat([title, entity], 0).unsqueeze(0)  # (1, 2)
            s_m = F.softmax(token_att, dim=1)
            doc_tit_ent.append(torch.mm(s_m, t_e_v))
            # Document summary built the same way as q_state.
            state_ = torch.cat([out_doc[i, -1, :fea], out_doc[i, 0, fea:]], 0)
            doc_states.append(state_.unsqueeze(0))
        t_e_vecs = torch.cat(doc_tit_ent, 0)

        # Sentence attention across documents.
        doc_states_v = torch.cat(doc_states, 0)
        doc_dot = torch.mm(doc_states_v, q_state.unsqueeze(1))  # (n_docs, 1)
        # Normalise over the documents (dim 0). The implicit-dim softmax
        # normalised each single-element row, giving all-ones weights.
        doc_sm = F.softmax(doc_dot, dim=0)
        t_doc_feat = torch.add(doc_states_v, t_e_vecs)
        doc_feat = torch.mm(doc_sm.view(1, -1), t_doc_feat)

        score = torch.mm(self.embed.weight, doc_feat.view(-1, 1)).view(1, -1)
        return F.log_softmax(score, dim=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号