def minibatch(self, dataset, subset, use_datasets, cache_data,
              shift_ratio=-1):
    """Build synthetic image and label batches, split into sub-batches.

    Random images and labels are created once and stored in local
    variables; the batch is then divided into ``self.num_splits`` pieces
    along the first (batch) dimension.

    Args:
        dataset: Object exposing ``num_classes``; labels are drawn
            uniformly from ``[0, num_classes)``.
        subset: Unused by this implementation.
        use_datasets: Unused by this implementation.
        cache_data: Unused by this implementation.
        shift_ratio: Unused by this implementation.

    Returns:
        A tuple ``(images_splits, labels_splits)`` of two lists, each with
        ``self.num_splits`` entries. Images are built with shape
        ``[batch_size, height, width, depth]`` and dtype ``self.dtype``;
        labels with shape ``[batch_size]`` and dtype ``tf.int32``. When
        ``self.num_splits == 1`` the lists hold the local variables
        themselves rather than the outputs of ``tf.split``.
    """
    # These arguments only exist to satisfy the shared minibatch signature.
    del subset, use_datasets, cache_data, shift_ratio

    input_shape = [self.batch_size, self.height, self.width, self.depth]
    images = tf.truncated_normal(
        input_shape,
        dtype=self.dtype,
        stddev=1e-1,
        name='synthetic_images')
    # For integer dtypes tf.random_uniform samples from [minval, maxval),
    # i.e. maxval is exclusive. Passing num_classes (not num_classes - 1)
    # lets every class id 0..num_classes-1 appear and keeps the range
    # non-empty when there is only a single class.
    labels = tf.random_uniform(
        [self.batch_size],
        minval=0,
        maxval=dataset.num_classes,
        dtype=tf.int32,
        name='synthetic_labels')

    # Note: Storing the batch in local variables avoids recomputation of the
    # random values, but still results in a H2D copy.
    images = tf.contrib.framework.local_variable(images, name='images')
    labels = tf.contrib.framework.local_variable(labels, name='labels')

    if self.num_splits == 1:
        images_splits = [images]
        labels_splits = [labels]
    else:
        images_splits = tf.split(images, self.num_splits, 0)
        labels_splits = tf.split(labels, self.num_splits, 0)
    return images_splits, labels_splits
# (Web page residue, not part of the code: "Comment list" / "Table of contents".)