train_cifar.py source code

python

Project: TensorNet-TF  Author: timgaripov
def do_eval(sess,
            eval_correct,
            loss,
            data_set):
    """Runs one evaluation pass over the full batches of a data set.

    Args:
        sess: The session in which the model has been trained.
        eval_correct: The Tensor that returns the number of correct predictions.
        loss: The Tensor that returns the loss for one batch.
        data_set: The set of images and labels to evaluate, from
            input_data.read_data_sets().

    Returns:
        A (precision, avg_loss) tuple, computed over full batches only.
    """
    # And run one epoch of eval.
    true_count = 0  # Counts the number of correct predictions.
    steps_per_epoch = data_set.num_examples // FLAGS.batch_size
    num_examples = steps_per_epoch * FLAGS.batch_size
    sum_loss = 0.0
    for step in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set.next_batch(FLAGS.batch_size),
                                   train_phase=False)
        res = sess.run([loss, eval_correct], feed_dict=feed_dict)
        sum_loss += res[0]
        true_count += res[1]
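    # Use float division so the result is not truncated under Python 2.
    precision = float(true_count) / num_examples
    # One loss value is accumulated per step, so average over the steps.
    avg_loss = sum_loss / steps_per_epoch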
    print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f  Loss: %.2f' %
          (num_examples, true_count, precision, avg_loss))
    return precision, avg_loss
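
The bookkeeping in do_eval can be checked without a TensorFlow session. The following is a minimal sketch, not part of TensorNet-TF: eval_steps and aggregate_eval are hypothetical helper names, the numbers are made up, and each per-batch loss is assumed to be a mean over the batch (the snippet above does not show how loss is defined).

```python
def eval_steps(num_examples, batch_size):
    """Number of full batches; a trailing partial batch is skipped."""
    return num_examples // batch_size


def aggregate_eval(batch_results, batch_size):
    """Combine per-batch (mean_loss, num_correct) pairs into epoch metrics."""
    steps = len(batch_results)
    if steps == 0:
        raise ValueError("need at least one full batch to evaluate")
    num_examples = steps * batch_size
    sum_loss = sum(loss for loss, _ in batch_results)
    true_count = sum(correct for _, correct in batch_results)
    # float() keeps the division exact under Python 2 as well.
    precision = float(true_count) / num_examples
    avg_loss = sum_loss / steps
    return precision, avg_loss


# Hypothetical run: 10 examples with batch size 4 give 2 full batches,
# so the last 2 examples are never evaluated.
steps = eval_steps(10, 4)
results = [(0.5, 3), (1.5, 1)]  # made-up (mean batch loss, correct count)
precision, avg_loss = aggregate_eval(results, batch_size=4)
print('steps=%d  precision=%.4f  avg_loss=%.2f' % (steps, precision, avg_loss))
```

Because only full batches are counted, the reported precision covers steps * batch_size examples rather than the whole data set, and a data set smaller than one batch would leave nothing to average.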