def reinforce_sample(self, x, max_length=30, temperature=1.0, argmax=False):
    """Decode one token sequence per input row, recording every step's choice.

    ``x`` is encoded once with ``self.encoder``; ``self.decoder`` is then run
    one step at a time for at most ``max_length`` steps, with the token chosen
    at each step fed back in as the next step's input.  The chosen tokens and
    the probabilities they were chosen from are stored, step by step, in
    ``self.multinomial_outputs`` and ``self.multinomial_probs`` (both lists
    are reset at the start of every call).
    NOTE(review): presumably these lists are consumed by a policy-gradient
    backward pass elsewhere -- confirm against the caller.

    Args:
        x: Input batch handed to ``self.encoder``; its first dimension is the
            batch size N and its tensor type is used for the returned value.
        max_length: Maximum number of decoding steps T.
        temperature: Divisor applied to the decoder output before the softmax.
        argmax: If true, take the most probable token at each step instead of
            sampling from the distribution.

    Returns:
        A ``Variable`` wrapping an N x T tensor with the same type as
        ``x.data``.  Each row holds the chosen tokens up to and including the
        first ``self.END``; every later position is ``self.NULL``.
    """
    N, T = x.size(0), max_length
    encoded = self.encoder(x)
    y = torch.LongTensor(N, T).fill_(self.NULL)
    done = torch.ByteTensor(N).fill_(0)
    cur_input = Variable(x.data.new(N, 1).fill_(self.START))
    h, c = None, None
    self.multinomial_outputs = []
    self.multinomial_probs = []
    for t in range(T):
        # logprobs is N x 1 x V
        logprobs, h, c = self.decoder(encoded, cur_input, h0=h, c0=c)
        logprobs = logprobs / temperature
        probs = F.softmax(logprobs.view(N, -1))  # Now N x V
        if argmax:
            _, cur_output = probs.max(1)
            # Depending on the PyTorch version, max(1) yields indices of shape
            # N or N x 1.  Normalise to N x 1 so the next decoder input has the
            # same shape as the initial START input and as the sampled branch.
            cur_output = cur_output.view(N, 1)
        else:
            cur_output = probs.multinomial()  # Now N x 1
        self.multinomial_outputs.append(cur_output)
        self.multinomial_probs.append(probs)
        # Flatten to N so the mask indexing and the END comparison below line
        # up element-for-element with `done` instead of relying on how an
        # N x 1 tensor combines with an N tensor.
        cur_output_data = cur_output.data.cpu().view(-1)
        not_done = logical_not(done)
        # Only rows that have not yet emitted END receive a token this step;
        # finished rows keep the NULL fill.
        y[:, t][not_done] = cur_output_data[not_done]
        done = logical_or(done, cur_output_data == self.END)
        cur_input = cur_output
        if done.sum() == N:
            break
    return Variable(y.type_as(x.data))
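# Web-page navigation labels captured along with this snippet (not code):
# "Comment list"
# "Table of contents"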