rhn_train.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:RecurrentHighwayNetworks 作者: julian121266 项目源码 文件源码
def evaluate_mc(data_path, dataset, load_model, mc_steps, seed):
  """Evaluate a saved model on the test split using Monte-Carlo averaging.

  Builds the model graph under the "model" variable scope, restores its
  variables from ``load_model`` and prints the test perplexity (and bits)
  returned by ``run_mc_epoch``. Evaluation uses a test model with
  ``batch_size`` and ``num_steps`` both set to 1.

  Args:
    data_path: Forwarded to ``get_data`` to locate the data.
    dataset: Dataset name, forwarded to ``get_data``.
    load_model: Checkpoint path passed to ``tf.train.Saver.restore``.
    mc_steps: Forwarded to ``run_mc_epoch``. Presumably the number of
      stochastic passes that are averaged. Must be positive.
    seed: Random seed forwarded to ``run_mc_epoch``.

  Raises:
    ValueError: If ``mc_steps`` is not positive.
  """
  # Validate up front, before printing or building any graph. Use an explicit
  # raise because ``assert`` statements are removed under ``python -O``.
  if mc_steps <= 0:
    raise ValueError("mc_steps must be positive, got %r" % (mc_steps,))

  ex.commands['print_config']()
  print("MC Evaluation of model:", load_model)
  # Only the test split is needed; the reader and other splits are ignored.
  _, (_, _, test_data, _) = get_data(data_path, dataset)

  config = get_config()
  val_config = deepcopy(config)
  test_config = deepcopy(config)
  # Evaluate one token at a time (no batching, unroll length 1).
  test_config.batch_size = test_config.num_steps = 1
  with tf.Session() as session:
    initializer = tf.random_uniform_initializer(-config.init_scale, config.init_scale)
    # The training model creates the variables; the validation and test models
    # share them via reuse=True. NOTE(review): the training/validation models
    # are presumably built so the graph matches the checkpoint -- confirm.
    with tf.variable_scope("model", reuse=None, initializer=initializer):
      _ = Model(is_training=True, config=config)
    with tf.variable_scope("model", reuse=True, initializer=initializer):
      _ = Model(is_training=False, config=val_config)
      mtest = Model(is_training=False, config=test_config)
    # No initializer op is run: Saver.restore assigns every variable covered
    # by a default-constructed Saver.
    saver = tf.train.Saver()
    saver.restore(session, load_model)

    print("Testing on non-batched Test ...")
    test_perplexity = run_mc_epoch(seed, session, mtest, test_data, tf.no_op(), test_config, mc_steps, verbose=True)
    print("Full Test Perplexity: %.3f, Bits: %.3f" % (test_perplexity, np.log2(test_perplexity)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号