def __init__(self, images, labels, dtype=tf.float32):
    """Construct a dataset from paired images and labels.

    Args:
      images: Collection of images. Only ``images.shape[0]`` (the number of
        examples) is inspected here; presumably a NumPy array -- TODO
        confirm against callers.
      labels: Collection of labels whose first dimension must equal that
        of ``images``.
      dtype: Requested image dtype. It is normalized with
        ``tf.as_dtype(dtype).base_dtype`` and must resolve to
        ``tf.float32``. Note that ``images`` is stored as given; no
        conversion to ``dtype`` happens in this method.

    Raises:
      TypeError: If ``dtype`` does not resolve to ``tf.float32``.
      ValueError: If ``images`` and ``labels`` differ in their first
        dimension.
    """
    # Normalize whatever was passed to a non-reference tf.DType.
    dtype = tf.as_dtype(dtype).base_dtype
    # Compare by equality, not identity: DType defines __eq__/__ne__, and
    # an `is` check only works while TensorFlow happens to intern DTypes.
    if dtype != tf.float32:
        raise TypeError('Invalid image dtype %r, expected float32' % dtype)
    # An explicit raise (unlike `assert`) is not stripped under `python -O`,
    # so mismatched inputs are always rejected.
    if images.shape[0] != labels.shape[0]:
        raise ValueError('images.shape: %s labels.shape: %s' %
                         (images.shape, labels.shape))
    self._num_examples = images.shape[0]
    self._images = images
    self._labels = labels
    # Iteration bookkeeping; presumably advanced by a batching method
    # elsewhere in the class (not visible here).
    self._epochs_completed = 0
    self._index_in_epoch = 0
评论列表
文章目录