def score_sent(self, sent, lang_h, ctx_h, temperature):
    """Score a sentence under the model via teacher forcing.

    Walks the RNN over `sent` one token at a time, starting from the
    'YOU:' speaker token, and sums the temperature-scaled log-probability
    the model assigns to each actual next word.

    Args:
        sent: sequence of word-index tensors to score (assumed to end
            with an <eos>-style token -- TODO confirm against callers).
        lang_h: language-model hidden state, with a batch dimension at
            axis 1 that is stripped for the duration of the loop.
        ctx_h: context encoding, batch dimension at axis 1 likewise.
        temperature: softmax temperature used to divide the logits.

    Returns:
        A tuple of (total log-probability as a Python float, final
        hidden state with the batch dimension restored, and all
        intermediate hidden states concatenated along dim 0).
    """
    total_logprob = 0
    # Work without the batch dimension; it is restored on return.
    lang_h = lang_h.squeeze(1)
    ctx_h = ctx_h.squeeze(1)

    # The first RNN input is the 'YOU:' speaker marker.
    tok = Variable(torch.LongTensor(1))
    tok.data.fill_(self.word_dict.get_idx('YOU:'))
    tok = self.to_device(tok)

    hidden_states = []
    for w in sent:
        # Condition the word embedding on the context, then advance the RNN.
        emb = torch.cat([self.word_encoder(tok), ctx_h], 1)
        lang_h = self.writer(emb, lang_h)
        hidden_states.append(lang_h)

        # Decode through the transpose of the (tied) embedding matrix,
        # scale by temperature, and subtract the max so the later
        # exponentiation inside log_softmax cannot overflow.
        dec_out = self.decoder(lang_h)
        logits = F.linear(dec_out, self.word_encoder.weight).div(temperature)
        logits = logits.add(-logits.max().data[0]).squeeze(0)
        # Suppress special tokens before normalizing.
        logits = logits.add(Variable(self.special_token_mask))

        log_probs = F.log_softmax(logits)
        # Accumulate the log-prob of the word that actually occurred.
        total_logprob += log_probs[w[0]].data[0]
        # Teacher forcing: the true word becomes the next input.
        tok = Variable(w)

    # One final step consumes the trailing <eos> token fed back above.
    emb = torch.cat([self.word_encoder(tok), ctx_h], 1)
    lang_h = self.writer(emb, lang_h)
    hidden_states.append(lang_h)

    # Restore the batch dimension for the caller.
    return total_logprob, lang_h.unsqueeze(1), torch.cat(hidden_states)
dialog_model.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录