def _load_session_pickle(name):
    """Unpickle and return the object stored as ``name`` in FLAGS.session_dir.

    NOTE(review): pickle.load can execute arbitrary code embedded in the
    file, so only point --session_dir at directories you trust.
    """
    path = os.path.join(FLAGS.session_dir, name)
    # Pickle data is binary: text mode makes pickle.load fail on Python 3
    # and can mangle the byte stream on Windows under Python 2.
    with open(path, 'rb') as f:
        return pickle.load(f)


def main(unused_args):
    """Restore a CharRNN from FLAGS.session_dir and sample from it interactively.

    Loads the character-to-id mapping (``labels.pkl``) and the model config
    (``config.pkl``) pickled in ``FLAGS.session_dir``. Builds the model in
    ``'infer'`` mode under the ``'model'`` variable scope, and restores the
    latest checkpoint there if one exists. Then it repeatedly prompts for a
    seed string and prints the text returned by ``m.sample`` together with
    ``FLAGS.num_steps`` divided by the sampling time (labelled "cps").

    The prompt loop ends cleanly on end-of-input (Ctrl-D) or Ctrl-C.

    Args:
      unused_args: Ignored; presumably the argv list handed over by the
        program's entry-point runner.
    """
    char_to_id = _load_session_pickle('labels.pkl')
    config = _load_session_pickle('config.pkl')

    with tf.variable_scope('model'):
        m = CharRNN('infer', config)

    # raw_input only exists on Python 2; fall back to input() on Python 3.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    with tf.Session() as sess:
        # Initialize first so every variable has a value; a successful
        # restore below overwrites them with the checkpointed weights.
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(FLAGS.session_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('%s restored' % ckpt.model_checkpoint_path)
        else:
            # Previously silent: without a checkpoint the samples come from
            # freshly initialized weights, which is easy to miss.
            print('no checkpoint found in %s; sampling with freshly '
                  'initialized weights' % FLAGS.session_dir)

        while True:
            try:
                seed = read_line('seed:')
            except (EOFError, KeyboardInterrupt):
                # Leave the prompt line tidy and exit instead of crashing.
                print('')
                break
            start_time = time.time()
            print(m.sample(sess, char_to_id, FLAGS.num_steps, seed))
            elapsed = time.time() - start_time
            # time.time() can be coarse; skip the rate rather than divide by 0.
            if elapsed > 0:
                print('%s cps' % (FLAGS.num_steps / elapsed))
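# Scraped web-page residue (not part of the program):
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").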