evaluate.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:squeezenet 作者: mtreml 项目源码 文件源码
def evaluate():
  """Build the evaluation graph and run `eval_once` on it, once or repeatedly.

  A fresh `tf.Graph` is created containing:
    * the input pipeline from `architecture.inputs(phase=FLAGS.phase)`,
    * the model from `architecture.inference(images, train=False)`,
    * an accuracy tensor (mean of `argmax(logits) == labels`),
    * a `tf.train.Saver` built from
      `ExponentialMovingAverage.variables_to_restore()`,
    * the merged summary op and a summary writer on `FLAGS.eval_dir`.

  `eval_once(saver, summary_writer, accuracy, summary_op)` is then called in
  a loop. The loop ends after the first pass when `FLAGS.run_once` is set;
  otherwise it sleeps `FLAGS.eval_interval_secs` seconds between passes and
  never returns normally.

  The summary writer is closed when the loop exits, whether normally or
  through an exception, so pending events are written out.
  """
  with tf.Graph().as_default() as g:

    # Get images and labels.
    images, labels = architecture.inputs(phase=FLAGS.phase)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = architecture.inference(images, train=False)

    # Flatten logits to one row of NUM_CLASSES scores per prediction.
    logits = tf.reshape(logits, (-1, NUM_CLASSES))
    # NOTE(review): a uniform offset on every logit should not change the
    # argmax below; kept as in the original -- confirm whether it is needed.
    epsilon = tf.constant(value=1e-4)
    logits = logits + epsilon

    # Predicted class index per row; labels are reshaped to the static shape
    # of the predictions and cast to int64 so the two can be compared.
    predictions = tf.argmax(logits, dimension=1)
    labels = tf.cast(tf.reshape(labels, shape=predictions.get_shape()),
                     dtype=tf.int64)

    # Accuracy is the fraction of rows whose prediction equals the label.
    correct_predictions = tf.equal(predictions, labels)
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, dtype=tf.float32))

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        architecture.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)

    # No initializer op is built here: one that is never run in a session
    # has no effect. Variable values are presumably loaded through `saver`
    # inside eval_once -- verify against eval_once.

    try:
      while True:
        eval_once(saver, summary_writer, accuracy, summary_op)
        if FLAGS.run_once:
          break
        time.sleep(FLAGS.eval_interval_secs)
    finally:
      # Flush pending events and release the event file on any exit path.
      summary_writer.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号