def __init__(self, images, labels, fake_data=False):
    """Store the images and labels and reset the epoch bookkeeping.

    When ``fake_data`` is false, ``images`` is flattened from
    ``[num examples, rows, columns, depth]`` to
    ``[num examples, rows * columns * depth]``, cast to float32 and
    multiplied by 1/255.  The per-example shape and the channel count
    seen before flattening are kept in ``imageShape`` and
    ``imageChannels``.

    In both modes, labels whose shape is one-dimensional are passed
    through ``dense_to_one_hot``; if that conversion raises, the
    traceback is printed and the labels are kept as given.

    Args:
        images: Array-like exposing ``shape``, ``reshape`` and ``astype``
            (presumably a 4-D numpy array of pixel values in [0, 255] --
            confirm against callers).  Stored untouched when
            ``fake_data`` is true.
        labels: Labels, one per image; must expose ``shape`` when
            ``fake_data`` is false.
        fake_data: When true, skip validation and preprocessing of
            ``images`` and report a fixed size of 10000 examples.

    Raises:
        AssertionError: If ``fake_data`` is false and the number of
            images differs from the number of labels.
    """
    if fake_data:
        self._num_examples = 10000
    else:
        assert images.shape[0] == labels.shape[0], (
            "images.shape: %s labels.shape: %s" % (images.shape,
                                                   labels.shape))
        self._num_examples = images.shape[0]
        # Remember the per-example layout before it is flattened away.
        self.imageShape = images.shape[1:]
        self.imageChannels = self.imageShape[2]
        # Convert shape from [num examples, rows, columns, depth]
        # to [num examples, rows*columns*depth].
        images = images.reshape(
            images.shape[0],
            images.shape[1] * images.shape[2] * images.shape[3])
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    try:
        if len(numpy.shape(self._labels)) == 1:
            # NOTE(review): the class count is the number of distinct
            # labels present; this presumably assumes the labels are
            # exactly 0..K-1 -- confirm against dense_to_one_hot.
            self._labels = dense_to_one_hot(
                self._labels, len(numpy.unique(self._labels)))
    except Exception:
        # Best effort: report the failure and keep the labels as given.
        # Catching Exception rather than using a bare except lets
        # KeyboardInterrupt and SystemExit propagate.
        traceback.print_exc()
    self._epochs_completed = 0
    self._index_in_epoch = 0
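# Web page residue, not part of the code: "评论列表" (comment list)
# Web page residue, not part of the code: "文章目录" (table of contents)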