evaluate.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:Image-Captioning 作者: zsdonghao 项目源码 文件源码
def run_once(global_step, target_cross_entropy_losses,
             target_cross_entropy_loss_weights, saver, summary_writer,
             summary_op):
  """Evaluates the latest model checkpoint once.

  Looks up the most recent checkpoint in the module-level `checkpoint_dir`,
  restores it into a fresh session and, unless the restored step is below the
  module-level `min_global_step`, runs `evaluate_model` on it.

  Args:
    global_step: Object whose `.name` identifies the global step tensor; it is
      passed to `tf.train.global_step` to read the restored step value.
    target_cross_entropy_losses: Forwarded unchanged to `evaluate_model`.
    target_cross_entropy_loss_weights: Forwarded unchanged to
      `evaluate_model`.
    saver: Instance of tf.train.Saver for restoring model Variables.
    summary_writer: Instance of SummaryWriter; forwarded to `evaluate_model`.
    summary_op: Op for generating model summaries; forwarded to
      `evaluate_model`.

  Returns:
    None. Returns early when no checkpoint exists or when the checkpoint is
    older than `min_global_step`.
  """
  # The latest checkpoint.
  model_path = tf.train.latest_checkpoint(checkpoint_dir)
  if not model_path:
    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
                    checkpoint_dir)
    return

  with tf.Session() as sess:
    # Load model from checkpoint.
    tf.logging.info("Loading model from checkpoint: %s", model_path)
    saver.restore(sess, model_path)
    # `step` is the evaluated value; `global_step` stays the graph handle.
    step = tf.train.global_step(sess, global_step.name)
    tf.logging.info("Successfully loaded %s at global step = %d.",
                    os.path.basename(model_path), step)
    if step < min_global_step:
      tf.logging.info("Skipping evaluation. Global step = %d < %d", step,
                      min_global_step)
      return

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # Run evaluation on the latest checkpoint.
    try:
      evaluate_model(
          sess=sess,
          target_cross_entropy_losses=target_cross_entropy_losses,
          target_cross_entropy_loss_weights=target_cross_entropy_loss_weights,
          global_step=step,
          summary_writer=summary_writer,
          summary_op=summary_op)
    # `as` form is valid on Python 2.6+ and 3; the old `Exception, e` form is
    # a SyntaxError on Python 3.
    except Exception as e:  # pylint: disable=broad-except
      tf.logging.error("Evaluation failed.")
      # Hand the exception to the coordinator rather than swallowing it
      # (per the Coordinator docs, join() re-raises a reported exception).
      coord.request_stop(e)

    # Always stop and wait for the queue-runner threads, success or failure.
    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号