def sample(self, num_samples, start_letter=0):
    """
    Autoregressively sample num_samples sequences of length self.max_seq_len.

    Every sequence starts from the token index ``start_letter``. At each step
    the token sampled for a row is fed back into ``self.forward`` as that
    row's next input.

    Inputs:
        - num_samples: number of sequences to draw (rows of the output).
        - start_letter: token index used as the first input of every sequence.

    Outputs: samples
        - samples: num_samples x max_seq_len LongTensor, one sampled sequence
          per row (moved to the GPU when ``self.gpu`` is set).
    """
    samples = torch.zeros(num_samples, self.max_seq_len).type(torch.LongTensor)
    h = self.init_hidden(num_samples)
    # All rows start from the same token.
    inp = autograd.Variable(torch.LongTensor([start_letter] * num_samples))

    if self.gpu:
        samples = samples.cuda()
        inp = inp.cuda()

    # NOTE(review): no gradients are needed while sampling; on PyTorch >= 0.4,
    # wrapping this loop in torch.no_grad() may avoid retaining a graph across
    # all max_seq_len steps through h -- confirm against the supported version.
    for i in range(self.max_seq_len):
        out, h = self.forward(inp, h)  # out: num_samples x vocab_size
        # multinomial normalises each row of non-negative weights, so this
        # draws one token per row from softmax(out).
        next_tokens = torch.multinomial(torch.exp(out), 1).view(-1)  # num_samples
        # Store the flattened tokens: assigning the raw (num_samples x 1)
        # multinomial result into the 1-D column samples[:, i] fails with a
        # size mismatch on current PyTorch.
        samples[:, i] = next_tokens.data
        # The sampled tokens become the next step's input.
        inp = next_tokens
    return samples
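# Web-page residue from where this snippet was copied (not part of the program):
# "评论列表" = "comment list", "文章目录" = "table of contents"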