mnist.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tf-tutorial 作者: zchen0211 项目源码 文件源码
def __init__(self, images, labels, fake_data=False, one_hot=False,
               dtype=tf.float32, trim_flag=False):
    """Construct a DataSet.

    Args:
      images: Image array indexed as [num_examples, rows, columns, depth]
        (depth must be 1). When `fake_data` is True it is stored as-is
        without any validation or conversion.
      labels: Label array whose first dimension must match `images`. Stored
        as-is.
      fake_data: If True, skip validation/conversion and report 10000
        examples.
      one_hot: Stored as `self.one_hot`; only used if `fake_data` is True.
      dtype: Either `uint8` to leave the input as `[0, 255]`, or `float32`
        to rescale into `[0, 1]`. Anything accepted by `tf.as_dtype` works;
        its base dtype is what gets checked.
      trim_flag: Accepted for backward compatibility but currently ignored.

    Raises:
      TypeError: If `dtype` is neither uint8 nor float32.
      AssertionError: If the image and label counts differ, or if
        `images.shape[3] != 1`. These checks use `assert`, so they are
        skipped when Python runs with `-O`.
    """
    dtype = tf.as_dtype(dtype).base_dtype
    if dtype not in (tf.uint8, tf.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape,
                                                 labels.shape))
      # Neither the depth check nor the rescale below changes the number of
      # examples, so it only needs to be recorded once.
      self._num_examples = images.shape[0]

      # Images are kept 4-D ([num_examples, rows, columns, depth]); they are
      # NOT flattened to [num_examples, rows*columns]. Only single-channel
      # input is accepted.
      assert images.shape[3] == 1
      if dtype == tf.float32:
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(np.float32)
        images = np.multiply(images, 1.0 / 255.0)

      log.info(str(images.shape))  # e.g. (50000, 28, 28, 1)
      log.info(str(labels.shape))
      log.info('using %d data for training' % self._num_examples)

    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号