sample.py source code

python

Project: char-rnn-tf  Author: liusiqi43
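import os
import pickle
import time
import tensorflow as tf  # TF 1.x graph-mode API (tf.Session, tf.variable_scope)

# FLAGS (session_dir, num_steps) and the CharRNN model class are defined
# elsewhere in the char-rnn-tf project and are not shown in this excerpt.

def main(unused_args):
  # Load the vocabulary and hyperparameters saved at training time.
  # Pickle files must be opened in binary mode under Python 3.
  with open(os.path.join(FLAGS.session_dir, 'labels.pkl'), 'rb') as f:
    char_to_id = pickle.load(f)
  with open(os.path.join(FLAGS.session_dir, 'config.pkl'), 'rb') as f:
    config = pickle.load(f)
  # Variables are created under the 'model' scope; tf.train.Saver matches
  # checkpoint entries by these variable names.
  with tf.variable_scope('model'):
    m = CharRNN('infer', config)
  with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver = tf.train.Saver(tf.global_variables())
    ckpt = tf.train.get_checkpoint_state(FLAGS.session_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('no checkpoint found in', FLAGS.session_dir)
      return
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path, 'restored')

    # Interactive loop: read a seed, generate num_steps characters,
    # and report throughput in characters per second (cps).
    while True:
      seed = input('seed:')
      start_time = time.time()
      print(m.sample(sess, char_to_id, FLAGS.num_steps, seed))
      print(FLAGS.num_steps / (time.time() - start_time), 'cps')

if __name__ == '__main__':
  tf.app.run()  # parses FLAGS, then calls main(argv)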
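The script delegates generation to `m.sample(sess, char_to_id, FLAGS.num_steps, seed)`, whose body is not shown here. A char-RNN sampler of this kind usually primes the network with the seed characters and then draws one character at a time from the predicted distribution, feeding each draw back in. The sketch below illustrates that loop in plain Python under stated assumptions: `next_probs` is a hypothetical stand-in for the trained network (it rereads the whole history instead of carrying an LSTM state), `sample_chars` is an illustrative name, and returning the seed followed by the generated text is a choice made here, not something the source confirms.

```python
import random
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-in for the trained network: given the ids fed so far,
# return one probability per character id. A real CharRNN would carry its
# recurrent state forward instead of re-reading the whole history.
ProbFn = Callable[[List[int]], Sequence[float]]


def sample_chars(next_probs: ProbFn, char_to_id: Dict[str, int],
                 num_steps: int, seed: str, rng: random.Random) -> str:
    """Prime with `seed`, then draw `num_steps` characters one at a time."""
    id_to_char = {i: c for c, i in char_to_id.items()}  # invert the vocabulary
    # Priming: every seed character must exist in the vocabulary.
    history = [char_to_id[c] for c in seed]
    generated = []
    for _ in range(num_steps):
        probs = next_probs(history)
        # Draw the next id in proportion to its predicted probability.
        next_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        history.append(next_id)  # feed the draw back in
        generated.append(id_to_char[next_id])
    return seed + ''.join(generated)


# Usage with a toy deterministic "model" that always predicts the next letter.
vocab = {'a': 0, 'b': 1, 'c': 2}


def toy_model(history: List[int]) -> List[float]:
    nxt = (history[-1] + 1) % len(vocab)
    return [1.0 if i == nxt else 0.0 for i in range(len(vocab))]


print(sample_chars(toy_model, vocab, 5, 'a', random.Random(0)))  # → abcabc
```

Sampling in proportion to the probabilities, rather than always taking the most likely character, tends to avoid repetitive loops in the output; passing a seeded `random.Random` makes runs reproducible.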