run_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:almond-nnparser 作者: Stanford-Mobisocial-IoT-Lab 项目源码 文件源码
def run():
    """Evaluate a saved model on a test set given on the command line.

    Expects two positional command-line arguments: the model directory and
    the test set path. With fewer arguments, a usage message is printed and
    the process exits with status 1.

    The configuration is loaded from ``./default.conf`` layered with
    ``model.conf`` from the model directory. The model graph is built on the
    CPU, the checkpoint named ``best`` in the model directory is restored,
    and the evaluator is run with ``save_to_file=True``.
    """
    argv = sys.argv
    if not len(argv) >= 3:
        usage = "** Usage: python3 " + argv[0] + " <<Model Directory>> <<Test Set>>"
        print(usage)
        sys.exit(1)

    # Seed NumPy before any configuration or data loading takes place.
    np.random.seed(42)

    model_dir = argv[1]
    test_set_path = argv[2]
    conf_files = ['./default.conf', os.path.join(model_dir, 'model.conf')]
    config = Config.load(conf_files)
    model = create_model(config)
    test_data = load_data(
        test_set_path,
        config.dictionary,
        config.grammar,
        config.max_length,
    )
    print("unknown", unknown_tokens)

    checkpoint_path = os.path.join(model_dir, 'best')

    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(1234)
        with tf.device('/cpu:0'):
            model.build()

            evaluator = Seq2SeqEvaluator(
                model,
                config.grammar,
                test_data,
                'test',
                config.reverse_dictionary,
                beam_size=config.beam_size,
                batch_size=config.batch_size,
            )
            saver = tf.train.Saver()

            with tf.Session() as session:
                saver.restore(session, checkpoint_path)

                # To debug numerical issues, the session can be wrapped here:
                #   session = tf_debug.LocalCLIDebugWrapperSession(session)
                #   session.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

                evaluator.eval(session, save_to_file=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号