def generate(decoder,
             prime_str='int ',
             predict_len=100,
             temperature=0.35,
             cuda=False,
             args=None,
             hidden=None):
    """Sample ``predict_len`` characters from a character-level RNN decoder.

    Args:
        decoder: Model called as ``decoder(input, hidden) -> (output, hidden)``
            and exposing ``init_hidden(batch_size)``.
        prime_str: Non-empty seed text. Its last character is the first input
            to the sampling loop.
        predict_len: Number of characters to generate.
        temperature: Divisor applied to the decoder output before
            exponentiation; lower values make sampling greedier.
        cuda: If True, move the inputs (and a freshly created hidden state)
            to the GPU with ``.cuda()``.
        args: Unused; kept for call-site compatibility.
        hidden: Optional hidden state to continue from. When None, a new state
            is created with ``decoder.init_hidden(1)`` and warmed up on all but
            the last character of ``prime_str``.

    Returns:
        Tuple ``(predicted, hidden)``: the generated text (``prime_str`` itself
        is not included) and the final hidden state, which can be passed back
        in as ``hidden`` to continue generation.

    Raises:
        ValueError: If ``prime_str`` is empty (there is no character to start
            sampling from).
    """
    if not prime_str:
        raise ValueError("prime_str must contain at least one character")

    prime_input = Variable(char_tensor(prime_str).unsqueeze(0))
    if cuda:
        # Move the prime unconditionally: even when the caller supplies
        # ``hidden``, sampling starts from ``prime_input[:, -1]``, which must
        # live on the same device as the decoder.
        prime_input = prime_input.cuda()

    # ``is None`` rather than ``not hidden``: truth-testing a multi-element
    # tensor raises RuntimeError.
    if hidden is None:
        hidden = decoder.init_hidden(1)
        if cuda:
            hidden = hidden.cuda()
        # Use the priming string (all but its last character) to "build up"
        # the hidden state.
        # NOTE(review): the original indentation was lost; priming is
        # reconstructed as running only for a freshly created hidden state.
        # Confirm this is the intended continuation behaviour.
        for p in range(len(prime_str) - 1):
            _, hidden = decoder(prime_input[:, p], hidden)

    inp = prime_input[:, -1]
    chars = []
    for _ in range(predict_len):
        output, hidden = decoder(inp, hidden)

        # Sample from the network as a multinomial distribution. Subtracting
        # the max before exp() leaves the (normalised) distribution unchanged
        # but stops low temperatures from overflowing to inf, which
        # torch.multinomial rejects.
        scores = output.data.view(-1).div(temperature)
        output_dist = (scores - scores.max()).exp()
        top_i = int(torch.multinomial(output_dist, 1)[0])

        # Record the predicted character and feed it back as the next input.
        predicted_char = all_characters[top_i]
        chars.append(predicted_char)
        inp = Variable(char_tensor(predicted_char).unsqueeze(0))
        if cuda:
            inp = inp.cuda()

    return ''.join(chars), hidden
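# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): leftover
# navigation text from the web page this code was copied from; not code.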