def sample(self, inputs, max_length):
    """Sample token sequences from the decoder via multinomial sampling.

    Decodes step by step, sampling one token per sentence at each step,
    and stops early once every sentence in the batch has emitted EOS.

    Args:
        inputs: model inputs forwarded to ``self.initialize``.
        max_length: maximum number of decoding steps.

    Returns:
        (samples, outputs): ``samples`` are the sampled token ids stacked
        over time (time x batch); ``outputs`` are the per-step decoder
        output states stacked over time.
    """
    targets, init_states = self.initialize(inputs, eval=False)
    emb, output, hidden, context = init_states

    outputs = []
    samples = []
    batch_size = targets.size(1)
    # One flag per sentence; set (and kept set) once that sentence samples EOS.
    num_eos = targets[0].data.byte().new(batch_size).zero_()

    for _ in range(max_length):
        output, hidden = self.decoder.step(emb, output, hidden, context)
        outputs.append(output)
        # Fix: pass dim explicitly — implicit-dim softmax is deprecated and
        # ambiguous; normalize over the vocabulary axis.
        dist = F.softmax(self.generator(output), dim=-1)
        sample = dist.multinomial(1, replacement=False).view(-1).data
        samples.append(sample)

        # Stop decoding once all sentences in the batch have reached EOS.
        num_eos |= (sample == lib.Constants.EOS)
        if num_eos.sum() == batch_size:
            break

        # Feed the sampled tokens back as the next step's input embeddings.
        emb = self.decoder.word_lut(Variable(sample))

    outputs = torch.stack(outputs)
    samples = torch.stack(samples)
    return samples, outputs
# NOTE(review): removed stray web-scrape residue ("评论列表" / "文章目录",
# i.e. blog "comment list" / "article table of contents") — it was not
# valid Python and did not belong to this source file.