import os
import time
from datetime import datetime

import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.python.platform import tf_logging as logging

# The flags referenced below (iter_size, profile_iterations, train_dir) are
# assumed to be defined elsewhere in the training script via tf.app.flags.
FLAGS = tf.app.flags.FLAGS


def train_step(sess, train_op, global_step, train_step_kwargs):
  """Function that takes a gradient step and specifies whether to stop.

  Args:
    sess: The current session.
    train_op: An `Operation` that computes the gradients and returns the
      total loss, or, when FLAGS.iter_size > 1, a dictionary of such
      operations whose last entry returns the total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.
  """
  start_time = time.time()
  if FLAGS.iter_size == 1:
    # For debugging specific endpoint values, set the train file to a single
    # image and use pdb here:
    # import pdb
    # pdb.set_trace()
    if FLAGS.profile_iterations:
      # Collect a full execution trace for this step and dump it as a Chrome
      # trace file (viewable at chrome://tracing).
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
      total_loss, np_global_step = sess.run([train_op, global_step],
                                            options=run_options,
                                            run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)
      ctf = tl.generate_chrome_trace_format()
      with open(os.path.join(FLAGS.train_dir,
                             'timeline_%08d.json' % np_global_step), 'w') as f:
        f.write(ctf)
    else:
      total_loss, np_global_step = sess.run([train_op, global_step])
  else:
    # iter_size > 1: run the first iter_size - 1 sub-steps, then the final
    # sub-step, which is the entry that returns the total loss (gradients are
    # effectively accumulated over iter_size batches).
    for j in range(FLAGS.iter_size - 1):
      sess.run([train_op[j]])
    total_loss, np_global_step = sess.run(
        [train_op[FLAGS.iter_size - 1], global_step])
  time_elapsed = time.time() - start_time

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      logging.info('%s: global step %d: loss = %.4f (%.2f sec)',
                   datetime.now(), np_global_step, total_loss, time_elapsed)

  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False

  return total_loss, should_stop
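

# A minimal usage sketch (an assumption, not part of the original snippet) of
# how train_step is usually driven. The 'should_stop' / 'should_log' tensors
# mirror the ones tf.contrib.slim's training loop builds internally; the
# function can also be handed to slim.learning.train through its
# train_step_fn argument. `train_op` and `global_step` are assumed to be
# built elsewhere by the training script; `example_training_loop`,
# `number_of_steps` and `log_every_n_steps` are hypothetical names/values.
def example_training_loop(train_op, global_step,
                          number_of_steps=20000, log_every_n_steps=10):
  train_step_kwargs = {
      'should_stop': tf.greater_equal(global_step, number_of_steps),
      'should_log': tf.equal(tf.mod(global_step, log_every_n_steps), 0),
  }
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.start_queue_runners(sess=sess)
    while True:
      total_loss, should_stop = train_step(sess, train_op, global_step,
                                           train_step_kwargs)
      if should_stop:
        break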