import torch
import torch.nn.functional as F
from torch.autograd import Variable  # legacy wrapper; a no-op since PyTorch 0.4


def forward(self, qu, w, e_p):
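    """Score candidate answers for a question against its candidate documents.

    qu:  question token ids; the indexing below assumes shape (1, q_len).
    w:   document token ids, shape (n_docs, doc_len).
    e_p: entity-token position within each document.
    Returns a (1, vocab_size) row of log-softmax scores over the embedding vocabulary.
    """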
qu = Variable(qu)
w = Variable(w)
embed_q = self.embed(qu)
embed_w = self.embed(w)
    b_size = embed_w.size(0)  # number of candidate documents in this batch
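    # tile the learned initial hidden state across the document batch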
h0_doc = Variable(torch.cat([self.h0_doc for _ in range(b_size)], 1))
out_qus, h_qus = self.rnn_qus(embed_q, self.h0_q)
out_doc, h_doc = self.rnn_doc(embed_w, h0_doc)
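    # question vector: concatenate the forward-direction final state with the
    # backward-direction first state of the BiRNN (assumes batch_first outputs)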
q_state = torch.cat([out_qus[0,-1,:self.config.rnn_fea_size], out_qus[0,0,self.config.rnn_fea_size:]],0)
# token attention
doc_tit_ent_dot = []
doc_tit_ent = []
doc_states = []
for i,k in enumerate(e_p):
        # memory: hidden states at the title token (position 1) and the entity
        # token (position k); self.cat must stack them into a (2, 2*hidden)
        # matrix so the (1, 2) attention row below can left-multiply it
t_e_v = self.cat(out_doc[i,1], out_doc[i,k])
        # relevance of each memory slot to the question state
        title = torch.dot(out_doc[i,1], q_state)
        entity = torch.dot(out_doc[i,k], q_state)
        # torch.dot returns 0-dim tensors, so stack (not cat) into a (1, 2) row
        token_att = torch.stack([title, entity], 0).unsqueeze(0)
        s_m = F.softmax(token_att, dim=1)
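        # attention-weighted mix of the two token vectors: (1, 2) @ (2, 2H) -> (1, 2H)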
att_v = torch.mm(s_m, t_e_v)
doc_tit_ent.append(att_v)
        # document vector: concatenate forward-final and backward-first states,
        # mirroring the question representation above
state_ = torch.cat([out_doc[i,-1,:self.config.rnn_fea_size], out_doc[i,0,self.config.rnn_fea_size:]],0)
doc_states.append(state_.unsqueeze(0))
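    # stack the per-document token-attention vectors into (n_docs, 2*hidden)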
t_e_vecs = torch.cat(doc_tit_ent,0)
# sentence attention
doc_states_v = torch.cat(doc_states, 0)
    doc_dot = torch.mm(doc_states_v, q_state.unsqueeze(1))  # (n_docs, 1) relevance scores
    # normalize over documents: without dim=0 the implicit dim (1) would softmax
    # each length-1 row to 1.0 and disable the attention entirely
    doc_sm = F.softmax(doc_dot, dim=0)
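    # fuse each document state with its token-attention vector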
t_doc_feat = torch.add(doc_states_v, t_e_vecs)
doc_feat = torch.mm(doc_sm.view(1,-1), t_doc_feat)
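    # score every vocabulary entry against the pooled document feature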
score = torch.mm(self.embed.weight, doc_feat.view(-1,1)).view(1,-1)
    score_n = F.log_softmax(score, dim=1)
return score_n
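

# Usage sketch under assumed shapes and config (names here are hypothetical,
# not from the original source): self.embed is an nn.Embedding whose dim equals
# 2 * self.config.rnn_fea_size, self.rnn_qus / self.rnn_doc are bidirectional
# batch_first RNNs, and self.h0_q / self.h0_doc are their initial hidden states.
#
#   qu  = torch.LongTensor(1, 10).random_(0, vocab_size)   # one question, 10 tokens
#   w   = torch.LongTensor(4, 30).random_(0, vocab_size)   # 4 candidate documents
#   e_p = [3, 7, 2, 9]                                     # entity position per document
#   log_scores = model(qu, w, e_p)                         # (1, vocab_size)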