# This excerpt assumes the module-level imports used elsewhere in the project:
import tensorflow as tf            # TensorFlow 1.x graph API
import tensorflow_fold as td       # TensorFlow Fold (dynamic batching)
from timeit import default_timer
# `data`, `hyper`, `param`, `build_model` and `logger` are project-local modules/objects.


def do_evaluation():
    # load data early to get node_type_num
    ds = data.load_dataset('data/statements')
    hyper.node_type_num = len(ds.word2int)

    (compiler, _, _, _, raw_accuracy, batch_size_op) = build_model()

    # locate the embedding checkpoint
    embedding_path = tf.train.latest_checkpoint(hyper.embedding_dir)
    if embedding_path is None:
        raise ValueError('Path to embedding checkpoint is incorrect: ' + hyper.embedding_dir)

    # locate the checkpoint for the remaining variables
    checkpoint_path = tf.train.latest_checkpoint(hyper.train_dir)
    if checkpoint_path is None:
        raise ValueError('Path to tbcnn checkpoint is incorrect: ' + hyper.train_dir)

    # track which variables will be restored (vs. freshly initialized)
    restored_vars = tf.get_collection_ref('restored')
    restored_vars.append(param.get('We'))
    restored_vars.extend(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    embeddingRestorer = tf.train.Saver({'embedding/We': param.get('We')})
    restorer = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    # evaluation loop
    total_size, test_gen = ds.get_split('test')
    test_set = compiler.build_loom_inputs(test_gen)
    with tf.Session() as sess:
        # Restore embedding matrix first
        embeddingRestorer.restore(sess, embedding_path)
        # Restore others
        restorer.restore(sess, checkpoint_path)
        # Initialize other variables
        gvariables = [v for v in tf.global_variables() if v not in tf.get_collection('restored')]
        sess.run(tf.variables_initializer(gvariables))

        num_epochs = 1 if not hyper.warm_up else 3
        for shuffled in td.epochs(test_set, num_epochs):
            logger.info('')
            logger.info('======================= Evaluation ====================================')
            accumulated_accuracy = 0.
            start_time = default_timer()
            for step, batch in enumerate(td.group_by_batches(shuffled, hyper.batch_size), 1):
                feed_dict = {compiler.loom_input_tensor: batch}
                accuracy_value, actual_bsize = sess.run([raw_accuracy, batch_size_op], feed_dict)
                accumulated_accuracy += accuracy_value * actual_bsize
                logger.info('evaluation in progress: running accuracy = %.2f, processed = %d / %d',
                            accuracy_value, (step - 1) * hyper.batch_size + actual_bsize, total_size)
            duration = default_timer() - start_time
            total_accuracy = accumulated_accuracy / total_size
            logger.info('evaluation accumulated accuracy = %.2f%% (%.1f samples/sec; %.2f seconds)',
                        total_accuracy * 100, total_size / duration, duration)
            logger.info('======================= Evaluation End =================================')
            logger.info('')
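A note on the accumulation step: `raw_accuracy` is a per-batch mean, so the loop multiplies it by `actual_bsize` before summing and only divides by `total_size` at the end. This weighting matters because the final batch is usually smaller than `hyper.batch_size`. Below is a minimal, self-contained sketch of that reduction; the `weighted_accuracy` helper and the sample numbers are illustrative, not part of the project code.

def weighted_accuracy(batch_results):
    """Combine per-batch accuracies into an overall accuracy.

    `batch_results` is an iterable of (accuracy, batch_size) pairs, mirroring
    what the loop above gets back from sess.run([raw_accuracy, batch_size_op]).
    """
    total_correct = 0.0
    total_samples = 0
    for accuracy, batch_size in batch_results:
        total_correct += accuracy * batch_size   # per-batch mean -> count of correct samples
        total_samples += batch_size
    return total_correct / total_samples if total_samples else 0.0


# Example: three full batches of 32 and a final partial batch of 10
print(weighted_accuracy([(0.75, 32), (0.81, 32), (0.78, 32), (0.90, 10)]))

Taking a plain average of the four batch accuracies would over-weight the small final batch; the weighted form gives every sample equal influence, which is what the evaluation loop above computes.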