train_image_classifier.py source code

python

Project: ActionVLAD, author: rohitgirdhar
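```python
import os
import time
from datetime import datetime

import tensorflow as tf  # TensorFlow 1.x API (tf.RunOptions, tf.RunMetadata)
from tensorflow.python.client import timeline
# Assumed: the full script's logging import is not shown in this excerpt.
from tensorflow.python.platform import tf_logging as logging

# Defined at module level in the full script; this function reads
# FLAGS.iter_size, FLAGS.profile_iterations and FLAGS.train_dir.
FLAGS = tf.app.flags.FLAGS


def train_step(sess, train_op, global_step, train_step_kwargs):
  """Takes one training step and reports whether to stop training.

  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the
      total loss. When `FLAGS.iter_size > 1`, an integer-indexed collection
      of `FLAGS.iter_size` such operations; they are run in order and the
      total loss is taken from the last one.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments; the optional
      entries 'should_log' and 'should_stop' are boolean tensors.
  Returns:
    The total loss and a boolean indicating whether or not to stop training.
  """
  start_time = time.time()
  if FLAGS.iter_size == 1:
    # To debug specific endpoint values, set the train file to a single
    # image and uncomment the pdb lines below.
    # import pdb
    # pdb.set_trace()
    if FLAGS.profile_iterations:
      # Trace this step in full and save it as a Chrome trace
      # (viewable at chrome://tracing).
      run_options = tf.RunOptions(
          trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
      total_loss, np_global_step = sess.run(
          [train_op, global_step],
          options=run_options,
          run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)
      ctf = tl.generate_chrome_trace_format()
      with open(os.path.join(FLAGS.train_dir,
                             'timeline_%08d.json' % np_global_step), 'w') as f:
        f.write(ctf)
    else:
      total_loss, np_global_step = sess.run([train_op, global_step])
  else:
    # Run the first iter_size - 1 sub-steps for their side effects only;
    # the last one also returns the loss, fetched with the global step.
    for j in range(FLAGS.iter_size - 1):
      sess.run([train_op[j]])
    total_loss, np_global_step = sess.run(
        [train_op[FLAGS.iter_size - 1], global_step])
  time_elapsed = time.time() - start_time

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      logging.info('%s: global step %d: loss = %.4f (%.2f sec)',
                   datetime.now(), np_global_step, total_loss, time_elapsed)

  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False

  return total_loss, should_stop
```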
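When `FLAGS.iter_size > 1`, `train_step` runs `train_op[0]` through `train_op[iter_size-2]` only for their side effects and fetches the loss and global step from the last op, so the logged `total_loss` is whatever `train_op[iter_size-1]` returns. The name `iter_size` matches Caffe's solver option for accumulating gradients over several mini-batches before a single update, which suggests the list of ops implements gradient accumulation, but the snippet does not show how `train_op` is built. The plain-Python sketch below (no TensorFlow; `grad_mse`, `accumulated_step` and `full_batch_step` are hypothetical names) illustrates that technique under this assumption on a one-parameter model:

```python
# Hypothetical sketch of gradient accumulation in plain Python (no TensorFlow).
# How ActionVLAD actually builds train_op[0..iter_size-1] is NOT shown in the
# snippet; this only illustrates the general technique on a 1-D model y = w * x.


def grad_mse(w, batch):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_step(w, micro_batches, lr):
    """One update split across len(micro_batches) passes (the iter_size).

    The first iter_size - 1 passes only accumulate gradients (like
    sess.run([train_op[j]])); the last pass also applies the averaged
    update (like train_op[iter_size - 1]).
    """
    iter_size = len(micro_batches)
    acc = 0.0
    for j in range(iter_size - 1):
        acc += grad_mse(w, micro_batches[j])  # accumulate only
    acc += grad_mse(w, micro_batches[-1])     # last sub-step...
    return w - lr * acc / iter_size           # ...then apply the mean gradient


def full_batch_step(w, batch, lr):
    """Ordinary gradient-descent step on the whole batch, for comparison."""
    return w - lr * grad_mse(w, batch)


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
w_acc = accumulated_step(0.5, [data[:2], data[2:]], lr=0.01)
w_full = full_batch_step(0.5, data, lr=0.01)
print(w_acc, w_full)  # equal up to floating-point rounding
```

With equal-sized micro-batches, averaging the per-micro-batch gradients reproduces the full-batch gradient, so both updates agree (here both give `w` ≈ 0.7285); the benefit is that only one micro-batch has to fit in memory at a time.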
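`train_step` only reads `should_log` and `should_stop` from `train_step_kwargs` and treats both as optional. They are created by the caller; when the function is passed as a custom `train_step_fn` to TF-slim's `learning.train`, they are presumably derived from its `number_of_steps` and `log_every_n_steps` arguments. The sketch below replaces the boolean tensors with plain Python predicates on the step number, an assumption made so the control flow can run without a session; `make_train_step_kwargs` and `run_loop` are hypothetical:

```python
# Hypothetical plain-Python stand-ins for the 'should_stop' / 'should_log'
# entries of train_step_kwargs. The predicates are assumed to follow TF-slim's
# learning.train (stop at number_of_steps, log every log_every_n_steps); the
# real entries are TensorFlow boolean tensors evaluated with sess.run.


def make_train_step_kwargs(number_of_steps=None, log_every_n_steps=1):
    kwargs = {}
    if number_of_steps:
        kwargs['should_stop'] = lambda step: step >= number_of_steps
    else:
        kwargs['should_stop'] = lambda step: False
    if log_every_n_steps > 0:  # otherwise 'should_log' is left out entirely
        kwargs['should_log'] = lambda step: step % log_every_n_steps == 0
    return kwargs


def run_loop(train_step_kwargs, max_iters=1000):
    """Repeat the same key checks train_step performs after each step."""
    step = 0
    logged = []
    while step < max_iters:
        step += 1  # stands in for train_op incrementing global_step
        if 'should_log' in train_step_kwargs:
            if train_step_kwargs['should_log'](step):
                logged.append(step)
        if 'should_stop' in train_step_kwargs:
            should_stop = train_step_kwargs['should_stop'](step)
        else:
            should_stop = False
        if should_stop:
            break
    return step, logged


print(run_loop(make_train_step_kwargs(number_of_steps=25, log_every_n_steps=10)))
# -> (25, [10, 20])
print(run_loop(make_train_step_kwargs(number_of_steps=5, log_every_n_steps=0)))
# -> (5, [])
```

Leaving `should_log` out of the dictionary (here via `log_every_n_steps=0`) simply disables logging, which is why `train_step` checks for the key before calling `sess.run`.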
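With `FLAGS.profile_iterations` set, each step writes a Chrome trace named `timeline_%08d.json` (the zero-padded global step) into `FLAGS.train_dir`; such files can be opened in Chrome's `chrome://tracing` viewer. For a quick textual summary, the hypothetical helper below sums the durations of complete (`"ph": "X"`) events per event name, assuming the standard Trace Event Format layout (a top-level `traceEvents` list, `dur` in microseconds). In TensorFlow's timeline output the event name is expected to be the op type, an assumption worth checking against a real file:

```python
# Hypothetical helper (not part of ActionVLAD): summarize the Chrome-trace
# files that train_step writes when FLAGS.profile_iterations is set.
import glob
import json
import os
from collections import defaultdict


def op_durations(trace):
    """Total 'dur' (microseconds) of complete ('X') events, grouped by name."""
    totals = defaultdict(int)
    for event in trace.get('traceEvents', []):
        if event.get('ph') == 'X':
            totals[event.get('name', '?')] += event.get('dur', 0)
    # Most expensive first.
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


def summarize_train_dir(train_dir, top=5):
    """Print the `top` most expensive event names for each saved step."""
    pattern = os.path.join(train_dir, 'timeline_*.json')
    for path in sorted(glob.glob(pattern)):  # %08d padding keeps step order
        with open(path) as f:
            trace = json.load(f)
        print(os.path.basename(path))
        for name, dur_us in op_durations(trace)[:top]:
            print('  %-40s %10.3f ms' % (name, dur_us / 1000.0))


if __name__ == '__main__':
    example = {'traceEvents': [
        {'ph': 'M', 'name': 'process_name', 'pid': 0},
        {'ph': 'X', 'name': 'MatMul', 'dur': 30},
        {'ph': 'X', 'name': 'Conv2D', 'dur': 50},
        {'ph': 'X', 'name': 'MatMul', 'dur': 40},
    ]}
    print(op_durations(example))  # [('MatMul', 70), ('Conv2D', 50)]
```

Because ops on different devices and threads overlap, the summed time can exceed the step's wall-clock time; the summary ranks ops by total cost rather than measuring latency.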