preprocessing.py source code

python

Project: benchmarks    Author: tensorflow
def minibatch(self, dataset, subset, use_datasets, cache_data,
                shift_ratio=-1):
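    """Get synthetic image batches.

    Returns:
      A pair (images_splits, labels_splits) of lists, each with
      self.num_splits entries: the image and label tensors for each split.
    """
    # Synthetic data ignores the real data source, so these arguments are
    # accepted only for interface compatibility and discarded.
    del subset, use_datasets, cache_data, shift_ratio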
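    input_shape = [self.batch_size, self.height, self.width, self.depth]
    # Random NHWC images from a normal distribution with stddev 0.1; values
    # more than two standard deviations from the mean are re-drawn.
    images = tf.truncated_normal(
        input_shape,
        dtype=self.dtype,
        stddev=1e-1,
        name='synthetic_images')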
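    # Random integer labels. For integer dtypes maxval is exclusive, so
    # num_classes (not num_classes - 1) covers every class index.
    labels = tf.random_uniform(
        [self.batch_size],
        minval=0,
        maxval=dataset.num_classes,
        dtype=tf.int32,
        name='synthetic_labels')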
    # Wrapping the tensors in local variables means the random values are
    # computed once, when the variables are initialized, instead of on every
    # step. This avoids recomputation but still results in a host-to-device
    # (H2D) copy.
    images = tf.contrib.framework.local_variable(images, name='images')
    labels = tf.contrib.framework.local_variable(labels, name='labels')
    if self.num_splits == 1:
      images_splits = [images]
      labels_splits = [labels]
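    else:
      # Divide the batch along dimension 0 into num_splits equal parts;
      # tf.split requires batch_size to be divisible by num_splits.
      images_splits = tf.split(images, self.num_splits, 0)
      labels_splits = tf.split(labels, self.num_splits, 0)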
    return images_splits, labels_splits
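The method above depends on TensorFlow 1.x ops. As a rough, dependency-free illustration of the same idea (truncated-normal images, uniform integer labels, and an even split along the batch dimension), here is a plain-Python sketch. The names `truncated_normal` and `synthetic_minibatch` and the `seed` parameter are hypothetical, nested lists stand in for tensors, and the local-variable caching and device copies are not modeled.

```python
import random
from typing import List, Tuple


def truncated_normal(stddev: float, rng: random.Random) -> float:
    """Sample N(0, stddev^2), re-drawing values beyond two stddevs.

    Mirrors the documented behavior of tf.truncated_normal, which drops
    and re-picks values more than two standard deviations from the mean.
    """
    while True:
        x = rng.gauss(0.0, stddev)
        if abs(x) <= 2.0 * stddev:
            return x


def synthetic_minibatch(batch_size: int, height: int, width: int, depth: int,
                        num_classes: int, num_splits: int,
                        seed: int = 0) -> Tuple[List, List]:
    """Hypothetical stand-in for the TF method: returns (image_splits, label_splits)."""
    # Like tf.split with an integer count, require an even division.
    if batch_size % num_splits != 0:
        raise ValueError("batch_size must be divisible by num_splits")
    rng = random.Random(seed)
    # Images as batch_size x height x width x depth nested lists (NHWC).
    images = [[[[truncated_normal(0.1, rng) for _ in range(depth)]
                for _ in range(width)]
               for _ in range(height)]
              for _ in range(batch_size)]
    # Labels uniform in [0, num_classes); randrange excludes the upper bound.
    labels = [rng.randrange(num_classes) for _ in range(batch_size)]
    # Split along the batch dimension into equal shards.
    shard = batch_size // num_splits
    image_splits = [images[i * shard:(i + 1) * shard] for i in range(num_splits)]
    label_splits = [labels[i * shard:(i + 1) * shard] for i in range(num_splits)]
    return image_splits, label_splits
```

As in the original method, a list is returned even when num_splits is 1, so callers can always iterate over the splits.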