def forward(self, qu, key, value, cand):
    """Score every candidate answer for every question via key-value memory.

    The summed question embedding attends over the summed key embeddings,
    reads a softmax-weighted mix of the value embeddings, and is updated
    ``self.config.hop`` times. The final question state is then dotted with
    every candidate embedding and normalised with a log-softmax.

    Every row of the question state is scored against the same set of
    keys/values, i.e. the memory is shared across rows rather than indexed
    per example.

    Args:
        qu: question token indices. Embeddings are summed over dim 1, so
            presumably ``(batch, q_len)`` -- TODO confirm against caller.
        key: key token indices, also summed over dim 1; presumably
            ``(n_mem, key_len)``.
        value: value indices, embedded without summing; presumably
            ``(n_mem,)`` so that ``embed_C(value)`` is ``(n_mem, embed_dim)``.
        cand: candidate-answer indices, embedded without summing; presumably
            ``(n_cand,)``.

    Returns:
        A 2-D tensor (rows of the question state x candidates) of
        log-probabilities, normalised along dim 1.
    """
    # The pre-0.4 ``Variable(...)`` wrappers were dropped: they are
    # deprecated no-ops on modern tensors.
    embed_q = self.embed_B(qu)
    embed_w1 = self.embed_A(key)
    # Values and candidates share one embedding table (embed_C).
    embed_w2 = self.embed_C(value)
    embed_c = self.embed_C(cand)

    # Bag-of-words: sum token embeddings over dim 1. Since PyTorch 0.2,
    # sum() drops the reduced dimension, so the former trailing .squeeze(1)
    # was a no-op -- or, with an embedding size of 1, collapsed the
    # embedding axis and broke the matrix products below.
    q_state = torch.sum(embed_q, 1)
    w1_state = torch.sum(embed_w1, 1)
    w2_state = embed_w2

    for _ in range(self.config.hop):
        # Similarity of each question-state row to each memory key.
        sent_dot = torch.mm(q_state, w1_state.t())
        # Explicit dim: the implicit-dim form is deprecated (and picks
        # dim=1 for 2-D input anyway), so this normalises over memories.
        sent_att = F.softmax(sent_dot, dim=1)
        # Attention-weighted read of the value memories.
        a_dot = torch.mm(sent_att, w2_state)
        # Residual-style update: q <- q + H(read).
        q_state = torch.add(self.H(a_dot), q_state)

    f_feat = torch.mm(q_state, embed_c.t())
    return F.log_softmax(f_feat, dim=1)
评论列表
文章目录