def _sample(self, state, context, mask, max_len=20):
    """
    Greedy (argmax) sampling from the decoder.

    Starting from a BOS row, repeatedly feeds the previous token's
    embedding through ``self._lstm_loop`` and appends the argmax token,
    stopping early once every sequence in the batch has emitted EOS.

    :param state: initial decoder state; dim 0 is the batch size
    :param context: attention context, passed through to ``_lstm_loop``
    :param mask: attention mask, passed through to ``_lstm_loop``
    :param max_len: maximum number of generated tokens (one extra step
        is allowed for the EOS token)
    :return: ``(toks, lens)`` — ``toks`` is the time-major stack of
        token rows (including the leading BOS row); ``lens`` holds the
        1-based step index at which each sequence first emitted EOS,
        or ``max_len + 1`` for sequences that never did.
    """
    batch_size = state.size(0)
    toks = [const_row(self.bos_token, batch_size, volatile=True)]
    # BUG FIX: torch.IntTensor(batch_size) returns UNINITIALIZED memory,
    # so the `lens == 0` tests below (and the all(lens) early break) were
    # reading garbage. Zero it explicitly at creation.
    lens = torch.IntTensor(batch_size).zero_()
    if torch.cuda.is_available():
        lens = lens.cuda()
    for l in range(max_len + 1):  # +1 because of EOS
        out, state, alpha = self._lstm_loop(state, self.embedding(toks[-1]), context, mask)
        # Greedy decoding: take the argmax over the vocabulary dimension.
        toks.append(out.max(1)[1].squeeze(1))
        # Record the length the first time a sequence emits EOS
        # (the `lens == 0` guard keeps the earliest EOS position).
        lens[(toks[-1].data == self.eos_token) & (lens == 0)] = l + 1
        if all(lens):
            # Every sequence in the batch has finished; stop early.
            break
    # Sequences that never emitted EOS are clamped to the maximum length.
    lens[lens == 0] = max_len + 1
    return torch.stack(toks, 0), lens
# Stray page text from the original web source (not code):
# "评论列表" (comment list) / "文章目录" (article table of contents)