eval.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目: tf_datasets 作者: tmattio 项目源码 文件源码
def main(args):
    """Continuously evaluate the LeNet model on the MNIST validation split.

    Builds the evaluation graph (data loading, model, streaming metrics and
    their scalar summaries) and then hands control to
    ``slim.evaluation.evaluation_loop``, which evaluates the model saved in
    ``FLAGS.checkpoint_dir`` every ``FLAGS.eval_interval_secs`` seconds and
    writes results to ``FLAGS.log_dir``.

    Args:
        args: Unused by this function. Presumably the argv list passed in by
            ``tf.app.run`` -- TODO confirm against the entry point.
    """
    # Load the MNIST validation split.
    mnist = tfd.get_dataset('mnist', FLAGS.data_dir)
    dataset = mnist.load('validation')

    # Build batched input tensors.
    images, labels = load_batch(
        dataset,
        FLAGS.batch_size)

    # Model output: prediction values for each class.
    predictions = lenet(images)

    # Reduce the per-class values (axis 1) to a single predicted class index.
    predictions = tf.to_int64(tf.argmax(predictions, 1))

    # Streaming metrics: value tensors for reporting plus update ops that
    # accumulate the statistics batch by batch.
    # NOTE(review): 'mse' is computed over class indices, which carries little
    # meaning for a classifier; kept so existing summaries stay unchanged.
    metrics_to_values, metrics_to_updates = metrics.aggregate_metric_map({
        'mse': metrics.streaming_mean_squared_error(predictions, labels),
        'accuracy': metrics.streaming_accuracy(predictions, labels),
    })

    # Write each metric as a scalar summary. ``items()`` works on both
    # Python 2 and 3; ``iteritems()`` does not exist on Python 3.
    for metric_name, metric_value in metrics_to_values.items():
        tf.summary.scalar(metric_name, metric_value)

    # Evaluate the checkpointed model every eval_interval_secs.
    # ``eval_op`` is materialized as a list: on Python 3 ``dict.values()`` is
    # a view object, which Session.run does not accept as a fetch structure.
    slim.evaluation.evaluation_loop(
        '',
        FLAGS.checkpoint_dir,
        FLAGS.log_dir,
        num_evals=FLAGS.num_evals,
        eval_op=list(metrics_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号