parallel.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:ray 作者: ray-project 项目源码 文件源码
def optimize(self, sess, batch_index, extra_ops=[], extra_feed_dict={},
                 file_writer=None):
        """Run a single step of SGD.

        Runs a SGD step over a slice of the preloaded batch with size given by
        self.per_device_batch_size and offset given by the batch_index
        argument.

        Updates shared model weights based on the averaged per-device
        gradients.

        When `file_writer` is given, the step is run with full tracing: a
        Chrome-format timeline is written to `timeline-sgd.json` under
        `self.logdir` (overwritten on every traced call) and the run metadata
        is added to `file_writer` under the tag `sgd_train_<batch_index>`.

        Args:
            sess: TensorFlow session.
            batch_index: Offset into the preloaded data. This value must be
                between `0` and `tuples_per_device`. The amount of data to
                process is always fixed to `per_device_batch_size`.
            extra_ops: Extra ops to run with this step (e.g. for metrics).
                Any sequence is accepted; it is copied, never mutated.
            extra_feed_dict: Extra args to feed into this session run. It is
                only read from, never mutated.
            file_writer: If specified, tf metrics will be written out using
                this.

        Returns:
            The outputs of extra_ops evaluated over the batch, in the same
            order as `extra_ops` (the train op's own output is dropped).
        """

        # Tracing is only requested when there is a writer to consume it.
        if file_writer:
            trace_level = tf.RunOptions.FULL_TRACE
        else:
            trace_level = tf.RunOptions.NO_TRACE
        run_options = tf.RunOptions(trace_level=trace_level)
        run_metadata = tf.RunMetadata()

        # Build a fresh dict so the caller's extra_feed_dict is left untouched.
        feed_dict = {self._batch_index: batch_index}
        feed_dict.update(extra_feed_dict)

        # list(...) lets callers pass a tuple (or other sequence) of ops; the
        # train op goes first so it can be sliced off the results below.
        outs = sess.run(
            [self._train_op] + list(extra_ops),
            feed_dict=feed_dict,
            options=run_options,
            run_metadata=run_metadata)

        if file_writer:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            trace_path = os.path.join(self.logdir, "timeline-sgd.json")
            # Use a context manager so the handle is closed (and the trace
            # flushed to disk) even if generating or writing the trace fails.
            with open(trace_path, "w") as trace_file:
                trace_file.write(trace.generate_chrome_trace_format())
            file_writer.add_run_metadata(
                run_metadata, "sgd_train_{}".format(batch_index))

        # Element 0 is the train op's result; callers only want extra_ops.
        return outs[1:]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号