def optimize(self, sess, batch_index, extra_ops=None, extra_feed_dict=None,
             file_writer=None):
    """Run a single step of SGD.

    Runs a SGD step over a slice of the preloaded batch with size given by
    self.per_device_batch_size and offset given by the batch_index
    argument. Updates shared model weights based on the averaged per-device
    gradients.

    Args:
        sess: TensorFlow session.
        batch_index: Offset into the preloaded data. This value must be
            between `0` and `tuples_per_device`. The amount of data to
            process is always fixed to `per_device_batch_size`.
        extra_ops: Extra ops to run with this step (e.g. for metrics).
            Defaults to no extra ops.
        extra_feed_dict: Extra args to feed into this session run.
            Defaults to an empty feed dict.
        file_writer: If specified, tf metrics will be written out using
            this.

    Returns:
        The outputs of extra_ops evaluated over the batch.
    """
    # Avoid mutable default arguments: a shared [] / {} default would be
    # reused across calls and could be mutated by callers.
    if extra_ops is None:
        extra_ops = []
    if extra_feed_dict is None:
        extra_feed_dict = {}
    # Full tracing is only needed when we are going to dump a timeline.
    if file_writer:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
    run_metadata = tf.RunMetadata()
    feed_dict = {self._batch_index: batch_index}
    feed_dict.update(extra_feed_dict)
    outs = sess.run(
        [self._train_op] + extra_ops,
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=run_metadata)
    if file_writer:
        # Dump a Chrome-trace-format timeline of this step for profiling.
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        # Use a context manager so the trace file is always closed
        # (the original leaked the file handle).
        with open(os.path.join(self.logdir, "timeline-sgd.json"),
                  "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        file_writer.add_run_metadata(
            run_metadata, "sgd_train_{}".format(batch_index))
    # outs[0] is the train op result; return only the extra_ops outputs.
    return outs[1:]