evaluate.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:GestureRecognition 作者: gkchai 项目源码 文件源码
def main(_):
    """Run the evaluation loop for the model selected by ``FLAGS.model``.

    Builds the model in evaluation mode on batches from the split named by
    ``FLAGS.split_name``, sets up accuracy, precision and recall metrics,
    and hands everything to ``slim.evaluation.evaluation_loop``, which reads
    checkpoints from ``FLAGS.train_dir`` and writes summaries to
    ``FLAGS.summaries_dir``.

    Args:
        _: Unused; ignored by this function.

    Returns:
        None. The return value of the evaluation loop is discarded.

    Raises:
        AssertionError: If ``FLAGS.train_dir`` is empty or unset.
    """
    assert FLAGS.train_dir, "--train_dir is required."

    # Start from an empty summaries directory: any existing contents are
    # deleted before the directory is recreated.
    if tf.gfile.Exists(FLAGS.summaries_dir):
        tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
    tf.gfile.MakeDirs(FLAGS.summaries_dir)

    config = configuration.Config()

    dataset_eval = loader.get_split(FLAGS.split_name, dataset_dir=FLAGS.data_dir)
    if FLAGS.preprocess_abs:
        preprocess_fn = tf.abs
    else:
        preprocess_fn = None

    # Whether it is a 2D input.
    is_2D = common.is_2D(FLAGS.model)

    series, labels, labels_one_hot = loader.load_batch(
        dataset_eval,
        batch_size=config.batch_size,
        is_2D=is_2D,
        preprocess_fn=preprocess_fn)

    # Build lazy model
    model = common.convert_name_to_instance(FLAGS.model, config, 'eval')

    endpoints = model.build(inputs=series, is_training=False)
    # Predicted class = index of the largest logit along axis 1.
    predictions = tf.to_int64(tf.argmax(endpoints.logits, 1))

    slim.get_or_create_global_step()

    # Choose the metrics to compute:
    # NOTE(review): if `metrics` is tf.contrib.metrics, streaming_precision and
    # streaming_recall treat their inputs as booleans -- confirm that this is
    # what is wanted when there are more than two classes.
    names_to_values, names_to_updates = metrics.aggregate_metric_map({
        'accuracy': metrics.streaming_accuracy(predictions, labels),
        'precision': metrics.streaming_precision(predictions, labels),
        'recall': metrics.streaming_recall(predictions, labels),
    })

    # Create the summary ops such that they also print out to std output:
    summary_ops = []
    # `.items()` instead of the Python-2-only `.iteritems()` so the loop runs
    # on both Python 2 and Python 3.
    for metric_name, metric_value in names_to_values.items():
        op = tf.summary.scalar(metric_name, metric_value)
        op = tf.Print(op, [metric_value], metric_name)
        summary_ops.append(op)

    slim.evaluation.evaluation_loop(
        master='',
        checkpoint_dir=FLAGS.train_dir,
        logdir=FLAGS.summaries_dir,
        # Materialise the update ops as a real list: on Python 3
        # `dict.values()` returns a view object, not a list.
        eval_op=list(names_to_updates.values()),
        # NOTE(review): this caps a batch count by a sample count; presumably
        # the cap should be num_samples / batch_size -- confirm against the
        # loader before changing.
        num_evals=min(FLAGS.num_batches, dataset_eval.num_samples),
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.num_of_steps,
        summary_op=tf.summary.merge(summary_ops),
        session_config=config.session_config,
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号