cifar10_input_data.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:dreamscape 作者: themattinthehatt 项目源码 文件源码
def __init__(self, images, labels, fake_data=False, one_hot=False,
                 dtype=tf.float32):
        """Construct a DataSet.

        one_hot arg is used only if fake_data is true.  `dtype` can be either
        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
        `[0, 1]`.

        Args:
            images: Image array laid out as
                `[num_examples, rows, columns, depth]`, whose last axis must
                equal `IMAGE_DEPTH`. It is not inspected when `fake_data` is
                true, and it is stored unchanged in that case.
            labels: Label array whose first dimension must match
                `images.shape[0]`. Stored as given.
            fake_data: If true, skip all image validation and processing and
                report 10000 examples.
            one_hot: Stored as `self.one_hot` when `fake_data` is true.
            dtype: `uint8` keeps pixel values as-is; `float32` converts them
                to float32 and scales them by 1/255.

        Raises:
            TypeError: If `dtype` is neither uint8 nor float32.
            ValueError: If `images` and `labels` disagree on the number of
                examples, or if the image depth is not `IMAGE_DEPTH`.
        """
        dtype = tf.as_dtype(dtype).base_dtype
        if dtype not in (tf.uint8, tf.float32):
            raise TypeError(
                'Invalid image dtype %r, expected uint8 or float32' % dtype)

        if fake_data:
            # No real images to measure, so the per-image dimensions are
            # unknown. Previously this path fell through into the shape
            # handling below and crashed on placeholders without `.shape`.
            self._num_examples = 10000
            self.one_hot = one_hot
            self._width = None
            self._height = None
            self._depth = None
        else:
            # Explicit raises instead of `assert`, which `python -O` strips.
            if images.shape[0] != labels.shape[0]:
                raise ValueError(
                    'images.shape: %s labels.shape: %s' % (images.shape,
                                                           labels.shape))
            if images.shape[3] != IMAGE_DEPTH:
                raise ValueError(
                    'images.shape: %s, expected depth %s' % (images.shape,
                                                             IMAGE_DEPTH))

            self._num_examples = images.shape[0]
            # NOTE(review): per the layout [num, rows, columns, depth], axis 1
            # is rows (height) and axis 2 is columns (width), so these two
            # names look swapped. Kept unchanged because other code may
            # reshape back using them -- confirm before renaming.
            self._width = images.shape[1]
            self._height = images.shape[2]
            self._depth = images.shape[3]

            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns*depth]. The explicit product
            # (rather than -1) keeps reshape valid for zero-example arrays.
            images = images.reshape(
                images.shape[0],
                images.shape[1] * images.shape[2] * images.shape[3])
            if dtype == tf.float32:
                # Convert from [0, 255] -> [0.0, 1.0].
                images = images.astype(np.float32)
                images = np.multiply(images, 1.0 / 255.0)

        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号