def generate_text(model, seed, length=512, top_n=10):
    """
    Generates text of the specified length from a trained model,
    starting from the given seed character sequence.
    """
    logger.info("generating %s characters from top %s choices.", length, top_n)
    logger.info('generating with seed: "%s".', seed)
    generated = seed
    encoded = encode_text(seed).astype(np.int32)
    model.predictor.reset_state()

    with chainer.using_config("train", False), chainer.no_backprop_mode():
        # feed the seed (all but its last character) through the model
        # to set the recurrent internal states
        for idx in encoded[:-1]:
            x = Variable(np.array([idx]))
            # input shape: [1]
            model.predictor(x)

        # generate one character at a time, feeding each sampled
        # character back in as the next input
        next_index = encoded[-1]
        for i in range(length):
            x = Variable(np.array([next_index], dtype=np.int32))
            # input shape: [1]
            probs = F.softmax(model.predictor(x))
            # output shape: [1, vocab_size]
            next_index = sample_from_probs(probs.data.squeeze(), top_n)
            # append sampled character to the generated sequence
            generated += ID2CHAR[next_index]

    logger.info("generated text: \n%s\n", generated)
    return generated
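The helper sample_from_probs is not shown above; a minimal sketch of top-n sampling consistent with how it is called here (a 1-D probability vector plus top_n) might look like the following. This is an assumption about its behaviour, not the original implementation, which may differ in detail.

import numpy as np

def sample_from_probs(probs, top_n=10):
    """
    Sample an index from a 1-D probability vector,
    restricted to the top_n most probable entries.
    (Hypothetical sketch; the original helper may differ.)
    """
    probs = np.asarray(probs, dtype=np.float64)
    # keep only the top_n largest probabilities
    top_indices = np.argsort(probs)[-top_n:]
    masked = np.zeros_like(probs)
    masked[top_indices] = probs[top_indices]
    # renormalise and draw one index
    masked /= masked.sum()
    return int(np.random.choice(len(masked), p=masked))

With such a helper (and the model, encode_text, and ID2CHAR defined elsewhere in the module), a call like generate_text(model, seed="The ", length=200, top_n=5) would produce 200 characters of sampled text; the seed string and parameter values here are only illustrative.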