test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:subword-lstm-lm 作者: claravania 项目源码 文件源码
def test(test_args):
    """Evaluate a trained language model on a test file and print its perplexity.

    The training-time configuration is loaded from ``config.pkl`` inside
    ``test_args.save_dir``, the model class is chosen from the stored
    ``unit`` / ``composition`` settings, the latest checkpoint in that
    directory is restored, and one evaluation pass is run over the test data.

    Args:
        test_args: object exposing ``save_dir`` (directory holding
            ``config.pkl`` and the checkpoint files) and ``test_file``
            (path handed to ``TextLoader.read_dataset``).

    Raises:
        SystemExit: if the stored unit/composition pair does not map to a
            known model class, or if no checkpoint can be found in
            ``save_dir``.
    """
    start = time.time()

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only point save_dir at directories produced by a trusted training run.
    with open(os.path.join(test_args.save_dir, 'config.pkl'), 'rb') as f:
        args = pickle.load(f)

    # The pickled config may carry a different save_dir than the one given
    # now, so the directory supplied by the caller takes precedence.
    args.save_dir = test_args.save_dir
    data_loader = TextLoader(args, train=False)
    test_data = data_loader.read_dataset(test_args.test_file)

    print(args.save_dir)
    print("Unit: " + args.unit)
    print("Composition: " + args.composition)

    # Copy vocabulary sizes from the loader onto the config object that is
    # passed to the model constructor below.
    args.word_vocab_size = data_loader.word_vocab_size
    if args.unit != "word":
        args.subword_vocab_size = data_loader.subword_vocab_size

    # Statistics of words
    print("Word vocab size: " + str(data_loader.word_vocab_size))

    # Statistics of sub units
    if args.unit != "word":
        print("Subword vocab size: " + str(data_loader.subword_vocab_size))
        if args.composition == "bi-lstm":
            # bilstm_num_steps is set to the per-word maximum reported by the
            # loader for the chosen sub-unit type.
            if args.unit == "char":
                args.bilstm_num_steps = data_loader.max_word_len
                print("Max word length:", data_loader.max_word_len)
            elif args.unit == "char-ngram":
                args.bilstm_num_steps = data_loader.max_ngram_per_word
                print("Max ngrams per word:", data_loader.max_ngram_per_word)
            elif args.unit == "morpheme" or args.unit == "oracle":
                args.bilstm_num_steps = data_loader.max_morph_per_word
                print("Max morphemes per word", data_loader.max_morph_per_word)

    # Word-level models ignore the composition setting; sub-word models are
    # selected by how sub-unit representations are composed.
    if args.unit == "word":
        lm_model = WordModel
    elif args.composition == "addition":
        lm_model = AdditiveModel
    elif args.composition == "bi-lstm":
        lm_model = BiLSTMModel
    else:
        sys.exit("Unknown unit or composition.")

    print("Begin testing...")
    with tf.Graph().as_default(), tf.Session() as sess:
        with tf.variable_scope("model"):
            mtest = lm_model(args, is_training=False, is_testing=True)

        # The saver is used here only to restore the trained parameters.
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=1)
        tf.initialize_all_variables().run()
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # Without a checkpoint the variables keep their fresh
            # initialisation and the reported perplexity would be
            # meaningless, so stop instead of evaluating an untrained model.
            sys.exit("No checkpoint found in " + args.save_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        # tf.no_op() is passed as the op argument of run_epoch — presumably
        # the slot a training op would occupy; confirm against run_epoch.
        test_perplexity = run_epoch(sess, mtest, test_data, data_loader, tf.no_op())
        print("Test Perplexity: %.3f" % test_perplexity)
        print("Test time: %.0f\n" % (time.time() - start))
        print("\n")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号