def main(args):
    """Continuously evaluate the LeNet model on the MNIST validation split.

    Builds a batched input pipeline over the validation set, computes
    streaming MSE and accuracy for the model's class predictions, exposes
    both metrics as TensorBoard scalar summaries, and re-evaluates the
    latest checkpoint in ``FLAGS.checkpoint_dir`` every
    ``FLAGS.eval_interval_secs`` seconds (runs until interrupted).

    Args:
        args: Unused; present for the tf.app.run() entry-point convention.
    """
    # Load the MNIST validation split from the configured data directory.
    mnist = tfd.get_dataset('mnist', FLAGS.data_dir)
    dataset = mnist.load('validation')
    # Input pipeline yielding (images, labels) tensors of batch_size each.
    images, labels = load_batch(
        dataset,
        FLAGS.batch_size)
    # Forward pass: per-class scores (logits) for each image in the batch.
    predictions = lenet(images)
    # Collapse per-class scores into a single predicted class id per image
    # so it is comparable against the integer labels.
    predictions = tf.to_int64(tf.argmax(predictions, 1))
    # Streaming metrics: aggregate_metric_map returns parallel dicts of
    # {name: value_tensor} and {name: update_op}.
    metrics_to_values, metrics_to_updates = metrics.aggregate_metric_map({
        'mse': metrics.streaming_mean_squared_error(predictions, labels),
        'accuracy': metrics.streaming_accuracy(predictions, labels),
    })
    # Expose each metric value as a TensorBoard scalar summary.
    # FIX: dict.iteritems() is Python 2 only (AttributeError on Python 3);
    # .items() works on both.
    for metric_name, metric_value in metrics_to_values.items():
        tf.summary.scalar(metric_name, metric_value)
    # Evaluate the model saved in the checkpoint directory, repeating
    # every eval_interval_secs and writing summaries/logs to log_dir.
    slim.evaluation.evaluation_loop(
        '',  # master: empty string runs evaluation in a local session
        FLAGS.checkpoint_dir,
        FLAGS.log_dir,
        num_evals=FLAGS.num_evals,
        # FIX: dict.values() is a view in Python 3; eval_op expects a
        # list of ops, so materialize it explicitly.
        eval_op=list(metrics_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)
# (Scrape artifact: the original page footer text "评论列表" / "文章目录"
#  — i.e. "comment list" / "article table of contents" — is blog
#  navigation boilerplate, not part of the program; safe to delete.)