def write_batch(self, bsz, lang_h, ctx_h, temperature, max_words=100):
    """Sample sentences for a whole batch of dialogues simultaneously.

    Args:
        bsz: batch size (number of dialogues generated in parallel).
        lang_h: language-model hidden state; squeezed/expanded to (bsz, hidden).
        ctx_h: context hidden state; squeezed/expanded to (bsz, ctx_dim).
        temperature: softmax temperature used when sampling each token.
        max_words: generation budget; stop after this many tokens even if
            some dialogues have not yet emitted '<selection>'.

    Returns:
        Tuple (outs, lang_hs): sampled token ids stacked along dim 0
        (steps, bsz), and the stacked hidden states (steps + 2, bsz, hidden)
        — the two extra entries are the initial state and the state after
        the last sampled word.
    """
    eod = self.word_dict.get_idx('<selection>')
    # resize the language and context hidden states to one row per dialogue
    lang_h = lang_h.squeeze(0).expand(bsz, lang_h.size(2))
    ctx_h = ctx_h.squeeze(0).expand(bsz, ctx_h.size(2))
    # every dialogue starts with the 'YOU:' token
    inpt = torch.LongTensor(bsz).fill_(self.word_dict.get_idx('YOU:'))
    inpt = Variable(self.to_device(inpt))

    outs, lang_hs = [], [lang_h.unsqueeze(0)]
    done = set()
    # generate until max_words tokens, or until every dialogue is finished
    for _ in range(max_words):
        # embed the current token and append the context embedding
        inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
        # advance the writer RNN one step
        lang_h = self.writer(inpt_emb, lang_h)
        out = self.decoder(lang_h)
        # output projection ties weights with the input word embeddings
        scores = F.linear(out, self.word_encoder.weight).div(temperature)
        # subtract the per-row max before exp() for numerical stability;
        # multinomial renormalizes, so the sampling distribution is unchanged
        scores.sub_(scores.max(1, keepdim=True)[0].expand(scores.size(0), scores.size(1)))
        out = torch.multinomial(scores.exp(), 1).squeeze(1)
        # save sampled tokens and hidden states
        outs.append(out.unsqueeze(0))
        lang_hs.append(lang_h.unsqueeze(0))
        inpt = out
        data = out.data.cpu()
        # mark dialogues that produced the end-of-dialogue token
        for i in range(bsz):
            if data[i] == eod:
                done.add(i)
        if len(done) == bsz:
            break

    # run the writer once more so lang_hs also covers the last sampled word
    inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
    lang_h = self.writer(inpt_emb, lang_h)
    lang_hs.append(lang_h.unsqueeze(0))
    # concatenate per-step outputs and hidden states into single tensors
    return torch.cat(outs, 0), torch.cat(lang_hs, 0)
dialog_model.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录