import os

import tensorflow as tf
from six.moves import cPickle

from model import Model  # Model class defined alongside this script


def sample(args):
    # Load the training configuration and the character/vocabulary mappings
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)

    # Rebuild the model graph in inference mode
    model = Model(saved_args, training=False)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())

        # Restore the latest checkpoint saved during training
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ret, hidden = model.sample(sess, chars, vocab, args.n, args.prime,
                                       args.sample)
            print("Number of characters generated: ", len(ret))
            for i in range(len(ret)):
                print("Generated character: ", ret[i])
                print("Associated hidden state:", hidden[i])