def benchmark_one_step(sess,
                       fetches,
                       step,
                       batch_size,
                       step_train_times,
                       trace_filename,
                       image_producer,
                       params,
                       summary_op=None):
  """Advance one step of benchmarking.

  Runs `fetches` (plus `summary_op`, if given) once in `sess` and appends the
  step's wall-clock duration to `step_train_times`. It periodically logs
  progress and, on step -1, optionally dumps a Chrome trace of the step.

  Args:
    sess: object whose `run(fetches, options=..., run_metadata=...)` executes
      the step (presumably a `tf.Session` -- confirm against callers).
    fetches: fetches passed to `sess.run()`. The result is indexed with string
      keys ('total_loss', 'top_1_accuracy', 'top_5_accuracy'), so this is
      expected to be a dict of fetches.
    step: index of the current step. Progress is logged only when `step >= 0`,
      on step 0 and after every `params.display_every` steps. A full trace is
      collected only when `step == -1` and `trace_filename` is non-empty.
    batch_size: batch size forwarded to `get_perf_timing_str()`.
    step_train_times: list to which this step's duration in seconds is
      appended (mutated in place).
    trace_filename: path the Chrome trace is written to on step -1. None or ''
      disables tracing.
    image_producer: object whose `notify_image_consumption()` is called after
      the step has run, or None when there is no producer to notify.
    params: benchmark parameters. The `forward_only` and `display_every`
      attributes are read.
    summary_op: optional extra fetch run together with `fetches`.

  Returns:
    The value fetched for `summary_op`, or None when `summary_op` is None.
  """
  # Decide once whether this step is traced. An empty filename is treated
  # like None so that a blank flag value does not attempt a dump to ''.
  should_trace = bool(trace_filename) and step == -1
  if should_trace:
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
  else:
    run_options = None
    run_metadata = None

  summary_str = None
  start_time = time.time()
  if summary_op is None:
    results = sess.run(fetches, options=run_options, run_metadata=run_metadata)
  else:
    results, summary_str = sess.run(
        [fetches, summary_op], options=run_options, run_metadata=run_metadata)

  # In forward-only mode 'total_loss' is not read from the results
  # (presumably it is not fetched); 0 is logged in its place.
  lossval = 0. if params.forward_only else results['total_loss']

  if image_producer is not None:
    image_producer.notify_image_consumption()
  # The recorded interval deliberately spans the sess.run() call and the
  # producer notification above.
  train_time = time.time() - start_time
  step_train_times.append(train_time)

  if step >= 0 and (step == 0 or (step + 1) % params.display_every == 0):
    log_str = '%i\t%s\t%.3f' % (
        step + 1, get_perf_timing_str(batch_size, step_train_times), lossval)
    if 'top_1_accuracy' in results:
      log_str += '\t%.3f\t%.3f' % (results['top_1_accuracy'],
                                   results['top_5_accuracy'])
    log_fn(log_str)

  if should_trace:
    log_fn('Dumping trace to %s' % trace_filename)
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with gfile.Open(trace_filename, 'w') as trace_file:
      trace_file.write(trace.generate_chrome_trace_format(show_memory=True))
  return summary_str
评论列表
文章目录