def forward(self, qu, w, cand):
    """Sample one document sentence by attention, then score the candidates.

    The question (``qu``) and every sentence of ``w`` are embedded and reduced
    to bag-of-words vectors by summing over axis 1. A softmax over the
    question/sentence dot products gives an attention distribution, from which
    one sentence index is sampled. The sampled sentence's embeddings are run
    through ``self.rnn_doc``; its final hidden state is added to the question
    vector, and the result is dotted with every candidate embedding to give
    log-probabilities over the candidates.

    Assumed shapes (not checked here -- TODO confirm against callers):
    ``qu`` is ``(1, q_len)`` token ids, ``w`` is ``(n_sents, sent_len)`` token
    ids and ``cand`` is ``(n_cands,)`` token ids. Only the first row of the
    sampled action is used to pick the sentence, so a single question
    (batch size 1) is assumed.

    NOTE(review): on PyTorch before 0.4 the sampled ``action`` supported
    ``.reinforce()``. That API no longer exists, and this method does not
    return the log-probability of the sampled sentence, so a policy-gradient
    caller must obtain it another way -- confirm how callers train the
    sentence selector.

    Args:
        qu: question token ids, embedded with ``self.embed_B``.
        w: document sentence token ids, embedded with ``self.embed_A``.
        cand: candidate token ids, embedded with ``self.embed_C``.

    Returns:
        ``(action, reward_prob)``. ``action`` is the output of
        ``multinomial(1)`` on the attention weights: one sampled sentence
        index per question row. ``reward_prob`` is the ``log_softmax`` of the
        candidate scores, with axis 0 squeezed away.
    """
    # Variable is a no-op wrapper on PyTorch >= 0.4; kept so older versions work.
    qu = Variable(qu)
    w = Variable(w)
    cand = Variable(cand)

    embed_q = self.embed_B(qu)
    embed_w1 = self.embed_A(w)
    embed_c = self.embed_C(cand)

    # Bag-of-words vectors. sum(dim=1) already drops the reduced axis, so no
    # extra squeeze is needed (squeezing axis 1 here would collapse a size-1
    # embedding axis and break the matrix products below).
    q_state = embed_q.sum(dim=1)
    w1_state = embed_w1.sum(dim=1)

    # Attention over sentences: one score per (question row, sentence) pair,
    # normalised across the sentence axis.
    sent_dot = torch.mm(q_state, w1_state.t())
    sent_att = F.softmax(sent_dot, dim=1)

    # num_samples is required by current PyTorch; draw one sentence per row.
    action = sent_att.multinomial(1)
    sent = embed_w1[action.data[0]]

    # NOTE(review): assumes rnn_doc returns (output, h_n) like a single-layer,
    # unidirectional nn.GRU / nn.RNN, so h_n.squeeze(0) is (batch, hidden);
    # hidden must equal the embedding size for the addition below -- confirm.
    sent_state = self.rnn_doc(sent, self.h0_doc)[-1].squeeze(0)

    q_state = q_state + sent_state
    f_feat = torch.mm(q_state, embed_c.t())
    reward_prob = F.log_softmax(f_feat, dim=1).squeeze(0)
    return action, reward_prob
评论列表
文章目录