# Assumed module-level imports (TF1-era API; not shown in this excerpt):
#   import os
#   import tensorflow as tf
#   from tensorflow.python.client import timeline
def load_data(self, sess, inputs, full_trace=False):
"""Bulk loads the specified inputs into device memory.
The shape of the inputs must conform to the shapes of the input
placeholders this optimizer was constructed with.
The data is split equally across all the devices. If the data is not
evenly divisible by the batch size, excess data will be discarded.
Args:
sess: TensorFlow session.
inputs: List of Tensors matching the input placeholders specified
at construction time of this optimizer.
full_trace: Whether to profile data loading.
Returns:
The number of tuples loaded per device.
"""
    feed_dict = {}
    assert len(self.input_placeholders) == len(inputs)
    for ph, arr in zip(self.input_placeholders, inputs):
        # Truncate each input so it splits evenly into SGD minibatches;
        # the excess rows are discarded.
        truncated_arr = make_divisible_by(arr, self.batch_size)
        feed_dict[ph] = truncated_arr
        truncated_len = len(truncated_arr)
    if full_trace:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
    run_metadata = tf.RunMetadata()
    # Run each tower's init op to copy the feed data into device memory.
    sess.run(
        [t.init_op for t in self._towers],
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=run_metadata)
    if full_trace:
        # Dump a Chrome-trace timeline of the data load for profiling.
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(os.path.join(self.logdir, "timeline-load.json"),
                  "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
    tuples_per_device = truncated_len // len(self.devices)
    assert tuples_per_device > 0, \
        "Too few tuples per batch, try increasing the training " \
        "batch size or decreasing the SGD batch size. Tried to split up " \
        "{} rows {}-ways in batches of {} (total across devices).".format(
            len(arr), len(self.devices), self.batch_size)
    # Each device's share must divide evenly into per-device minibatches.
    assert tuples_per_device % self.per_device_batch_size == 0
    return tuples_per_device
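

# A minimal sketch of the make_divisible_by helper called above; the actual
# implementation in the source may differ. It assumes a NumPy-style array
# and drops trailing rows so the length is an exact multiple of n.
def make_divisible_by(arr, n):
    if n == 0:
        return arr
    return arr[0:arr.shape[0] - arr.shape[0] % n]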
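

# Hedged usage sketch (TF1-style session API). `optimizer` stands for an
# already constructed instance of this class; the array shapes below are
# illustrative assumptions, not taken from the source:
#
#     import numpy as np
#     import tensorflow as tf
#
#     obs = np.zeros((1000, 4), dtype=np.float32)
#     actions = np.zeros((1000,), dtype=np.int64)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         per_device = optimizer.load_data(sess, [obs, actions])
#         # per_device is the number of tuples loaded on each device.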