def __init__(self, pkl_path, shuffle=False, distort=False,
             capacity=2000, image_per_thread=16):
    """Load a pickled image batch and build the TF input-queue ops.

    Args:
        pkl_path: Path of a pickle file holding a dict with at least the
            keys ``'data'`` (flat image rows) and ``'labels'``.
        shuffle: Stored on ``self._shuffle``; not used in this method.
        distort: Stored on ``self._distort``; not used in this method.
        capacity: Maximum number of elements the FIFO queue can hold.
        image_per_thread: Stored on ``self.image_per_thread``; not used in
            this method.
    """
    # Configuration flags; this method only records them.
    self._shuffle = shuffle
    self._distort = distort

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only point pkl_path at trusted data. If this is an official
    # CIFAR batch written by Python 2 and read under Python 3, str keys
    # like 'data' require pickle.load(..., encoding='latin1'); confirm.
    with open(pkl_path, 'rb') as handle:
        batch = pickle.load(handle)

    # Each row is reshaped to (3, 32, 32) (channel-major), then the channel
    # axis is moved last to obtain NHWC images of shape (N, 32, 32, 3).
    # copy(order='C') turns the transposed view into a C-contiguous array.
    channel_first = batch['data'].reshape([-1, 3, 32, 32])
    channel_last = channel_first.transpose((0, 2, 3, 1))
    self._images = channel_last.copy(order='C')
    # Original note says "numpy 1-D array"; presumably one label per image.
    self._labels = batch['labels']
    self.size = len(self._labels)

    # Each queue element is one float32 image (32, 32, 3) plus a scalar
    # int32 label. (A tf.RandomShuffleQueue with min_after_dequeue was
    # previously tried here instead.)
    self.queue = tf.FIFOQueue(capacity=capacity,
                              dtypes=[tf.float32, tf.int32],
                              shapes=[[32, 32, 3], []])

    # Batch-sized feed points; enqueue_many splits them along axis 0 into
    # individual queue elements.
    self.dataX = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
    self.dataY = tf.placeholder(dtype=tf.int32, shape=[None])
    self.enqueue_op = self.queue.enqueue_many([self.dataX, self.dataY])

    self.image_per_thread = image_per_thread
    self._image_summary_added = False
评论列表
文章目录