def forward(self, qu, w, cand):
    """Score every candidate against an encoded question; return log-probabilities.

    The question ids are embedded and run through ``self.rnn``. The output at
    the last position of dimension 1 is the question representation. That
    representation is dotted with the embedding of every candidate, and the
    scores are log-softmax-normalised across candidates.

    Args:
        qu: Question token ids, passed to ``self.embed``. Assumed to be
            ``(batch, seq_len)`` with ``self.rnn`` built using
            ``batch_first=True`` -- TODO confirm.
        w: Unused; kept so existing callers keep working.
        cand: Candidate ids, passed to ``self.embed``. ``torch.mm`` needs a
            2-D candidate matrix, so this is presumably a 1-D tensor of
            ``num_cand`` ids -- confirm against callers.

    Returns:
        A ``(batch, num_cand)`` tensor of log-softmax scores, normalised
        over dim 1 (the candidates).

    Side effects:
        Replaces ``self.h0``/``self.c0`` with the final recurrent state and
        detaches them in place. The next call therefore starts from this
        state, but gradients do not flow back into earlier calls. Because the
        state is reused, the batch size presumably must stay the same between
        calls.
    """
    # Variable() is a legacy wrapper (it simply returns the tensor on
    # PyTorch >= 0.4); kept so older releases keep working.
    qu = Variable(qu)
    cand = Variable(cand)
    embed_q = self.embed(qu)
    embed_cand = self.embed(cand)
    # self.rnn is assumed to be LSTM-like: it takes and returns an (h, c) pair.
    out, (self.h0, self.c0) = self.rnn(embed_q, (self.h0, self.c0))
    # Cut the autograd graph so the stored state does not keep this call's
    # graph alive and the next call does not backpropagate into it.
    self.h0.detach_()
    self.c0.detach_()
    # Last position along dim 1 (the last time step if the rnn is batch_first).
    q_state = out[:, -1, :]
    # (batch, hidden) @ (hidden, num_cand) -> (batch, num_cand) dot-product
    # scores; requires the embedding size to equal the rnn output size.
    f_fea_v = torch.mm(q_state, embed_cand.t())
    # Explicit dim=1: the implicit choice is deprecated and warns. For a 2-D
    # input PyTorch picked dim 1 anyway, so the results are unchanged.
    score_n = F.log_softmax(f_fea_v, dim=1)
    return score_n
评论列表
文章目录