tbcnn.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:tensorflow-tbcnn 作者: Aetf 项目源码 文件源码
def do_evaluation():
    """Restore a trained TBCNN model and log its accuracy on the test split.

    Steps:
      1. Load the ``'data/statements'`` dataset first, so that
         ``hyper.node_type_num`` is set before the model graph is built.
      2. Build the model.
      3. Restore the ``We`` embedding matrix from the latest checkpoint in
         ``hyper.embedding_dir``.
      4. Restore all trainable variables from the latest checkpoint in
         ``hyper.train_dir``.
      5. Initialize every other global variable.
      6. Run the test split through the model and log the accuracy.

    The test split is evaluated once, or three times when ``hyper.warm_up``
    is set (presumably so later passes are timed after warm-up — TODO confirm).

    Returns:
        The accuracy of the last evaluation pass as a fraction (it is logged
        as a percentage). Returns ``0.`` if that pass saw no samples, and
        ``None`` if no pass ran.

    Raises:
        ValueError: if no checkpoint is found in ``hyper.embedding_dir`` or
            in ``hyper.train_dir``.
    """
    # Load data early: the vocabulary size becomes node_type_num, which the
    # model builder needs.
    ds = data.load_dataset('data/statements')
    hyper.node_type_num = len(ds.word2int)

    (compiler, _, _, _, raw_accuracy, batch_size_op) = build_model()

    # Checkpoint holding the pre-trained embedding matrix.
    embedding_path = tf.train.latest_checkpoint(hyper.embedding_dir)
    if embedding_path is None:
        raise ValueError('Path to embedding checkpoint is incorrect: ' + hyper.embedding_dir)

    # Checkpoint holding the remaining TBCNN variables.
    checkpoint_path = tf.train.latest_checkpoint(hyper.train_dir)
    if checkpoint_path is None:
        raise ValueError('Path to tbcnn checkpoint is incorrect: ' + hyper.train_dir)

    # Record every variable that is restored from a checkpoint, so that the
    # initializer below skips them (initializing would clobber restored values).
    restored_vars = tf.get_collection_ref('restored')
    restored_vars.append(param.get('We'))
    restored_vars.extend(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    embeddingRestorer = tf.train.Saver({'embedding/We': param.get('We')})
    restorer = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    # Evaluation loop.
    total_size, test_gen = ds.get_split('test')
    test_set = compiler.build_loom_inputs(test_gen)
    total_accuracy = None
    with tf.Session() as sess:
        # Restore the embedding matrix first, then the other variables.
        embeddingRestorer.restore(sess, embedding_path)
        restorer.restore(sess, checkpoint_path)
        # Initialize the remaining variables. get_collection returns a fresh
        # list copy, so fetch it once instead of once per variable.
        restored = tf.get_collection('restored')
        gvariables = [v for v in tf.global_variables() if v not in restored]
        sess.run(tf.variables_initializer(gvariables))

        num_epochs = 1 if not hyper.warm_up else 3
        for shuffled in td.epochs(test_set, num_epochs):
            logger.info('')
            logger.info('======================= Evaluation ====================================')
            accumulated_accuracy = 0.
            processed = 0
            start_time = default_timer()
            for batch in td.group_by_batches(shuffled, hyper.batch_size):
                feed_dict = {compiler.loom_input_tensor: batch}
                accuracy_value, actual_bsize = sess.run([raw_accuracy, batch_size_op], feed_dict)
                # Weight by the actual batch size (the last batch may be
                # smaller); this assumes raw_accuracy is a per-batch mean.
                accumulated_accuracy += accuracy_value * actual_bsize
                processed += actual_bsize
                logger.info('evaluation in progress: running accuracy = %.2f, processed = %d / %d',
                            accumulated_accuracy / max(processed, 1), processed, total_size)
            duration = default_timer() - start_time
            if processed != total_size:
                logger.warning('evaluated %d samples but the test split reports %d',
                               processed, total_size)
            # Divide by the samples actually evaluated, and guard against an
            # empty split or a zero-length timing.
            total_accuracy = accumulated_accuracy / processed if processed else 0.
            throughput = processed / duration if duration > 0 else 0.
            logger.info('evaluation accumulated accuracy = %.2f%% (%.1f samples/sec; %.2f seconds)',
                        total_accuracy * 100, throughput, duration)
            logger.info('======================= Evaluation End =================================')
            logger.info('')
    return total_accuracy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号