parallel.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:ray 作者: ray-project 项目源码 文件源码
def load_data(self, sess, inputs, full_trace=False):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session used to run each tower's ``init_op``.
            inputs: List of arrays matching the input placeholders specified
                at construction time of this optimizer, in the same order.
                NOTE(review): all inputs are presumably expected to have the
                same number of rows; only the last one's length is used for
                the per-device count -- confirm with callers.
            full_trace: Whether to profile data loading. When true, a Chrome
                trace is written to ``<self.logdir>/timeline-load.json``.

        Returns:
            The number of tuples loaded per device, as an int.

        Raises:
            AssertionError: If the number of inputs differs from the number of
                input placeholders, if fewer than one tuple per device remains
                after truncation, or if the truncated rows cannot be split
                evenly into per-device batches.
        """
        assert len(self.input_placeholders) == len(inputs)

        feed_dict = {}
        # Both stay 0 when there are no inputs, so the checks at the end fail
        # with a readable message instead of an UnboundLocalError.
        num_rows = 0
        truncated_len = 0
        for ph, arr in zip(self.input_placeholders, inputs):
            # Drop trailing rows so the row count is a multiple of batch_size.
            truncated_arr = make_divisible_by(arr, self.batch_size)
            feed_dict[ph] = truncated_arr
            num_rows = len(arr)
            truncated_len = len(truncated_arr)

        if full_trace:
            trace_level = tf.RunOptions.FULL_TRACE
        else:
            trace_level = tf.RunOptions.NO_TRACE
        run_options = tf.RunOptions(trace_level=trace_level)
        run_metadata = tf.RunMetadata()

        sess.run(
            [t.init_op for t in self._towers],
            feed_dict=feed_dict,
            options=run_options,
            run_metadata=run_metadata)
        if full_trace:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            trace_path = os.path.join(self.logdir, "timeline-load.json")
            # The context manager flushes the trace and releases the handle;
            # a bare open() here would leak the file descriptor.
            with open(trace_path, "w") as trace_file:
                trace_file.write(trace.generate_chrome_trace_format())

        # Integer division: "/" yields a float on Python 3, which is not a
        # usable tuple count. An uneven split is still rejected below, just
        # as the fractional float failed the modulo check before.
        tuples_per_device, remainder = divmod(truncated_len, len(self.devices))
        assert tuples_per_device > 0, \
            "Too few tuples per batch, try increasing the training " \
            "batch size or decreasing the sgd batch size. Tried to split up " \
            "{} rows {}-ways in batches of {} (total across devices).".format(
                num_rows, len(self.devices), self.batch_size)
        assert remainder == 0
        assert tuples_per_device % self.per_device_batch_size == 0
        return tuples_per_device
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号