test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:rnn-lm-tensorflow 作者: claravania 项目源码 文件源码
def test(test_args):
    """Evaluate a saved language model on a test file and print the result.

    Loads the pickled configuration stored as ``config.pkl`` under
    ``test_args.save_dir``, reads ``test_args.test_file`` through a
    ``TextLoader``, builds ``WordLM`` with ``is_training=False`` and
    ``is_testing=True``, restores the latest checkpoint listed under the
    loaded configuration's ``save_dir`` (when one exists) and runs a single
    ``run_epoch`` pass with a no-op training operation.

    Args:
        test_args: object exposing ``save_dir`` (directory that contains
            ``config.pkl``) and ``test_file`` (passed to
            ``TextLoader.read_dataset``).

    Returns:
        None. The value returned by ``run_epoch`` is printed as the test
        perplexity, followed by the elapsed wall-clock time in seconds.
    """
    start = time.time()

    # SECURITY: unpickling executes arbitrary code; only point save_dir at
    # directories produced by a trusted training run.
    # NOTE(review): the file is opened in text mode, which must match the mode
    # used by whatever wrote config.pkl -- confirm against the training script
    # before switching to 'rb'.
    with open(os.path.join(test_args.save_dir, 'config.pkl')) as f:
        args = cPickle.load(f)

    data_loader = TextLoader(args, train=False)
    test_data = data_loader.read_dataset(test_args.test_file)

    # The model is sized from the vocabulary the loader reports.
    args.word_vocab_size = data_loader.word_vocab_size
    # Single-argument print(...) emits the same text under Python 2 and is
    # consistent with the result lines at the end of this function.
    print("Word vocab size: " + str(data_loader.word_vocab_size) + "\n")

    # Model
    lm_model = WordLM

    print("Begin testing...")
    # If using gpu:
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
    # gpu_config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
    # add parameters to the tf session -> tf.Session(config=gpu_config)
    with tf.Graph().as_default(), tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-args.init_scale, args.init_scale)
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            mtest = lm_model(args, is_training=False, is_testing=True)

        saver = tf.train.Saver(tf.all_variables())
        # Initialise first so every variable has a value even when no
        # checkpoint can be restored below.
        tf.initialize_all_variables().run()

        # NOTE(review): the checkpoint is looked up under args.save_dir (from
        # the pickled config), not test_args.save_dir -- confirm this is
        # intended if model directories can be relocated after training.
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Previously this case was silent, so the number printed below
            # looked like a real result although nothing had been restored.
            print("Warning: no checkpoint found in %s; evaluating freshly "
                  "initialized (untrained) weights." % args.save_dir)

        # tf.no_op() stands in for the training op so no update is applied.
        test_perplexity = run_epoch(sess, mtest, test_data, data_loader, tf.no_op())
        print("Test Perplexity: %.3f" % test_perplexity)
        print("Test time: %.0f" % (time.time() - start))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号