def read_data_files(self, subset='train'):
    """Reads CIFAR batch files and returns images and labels as numpy arrays.

    Args:
      subset: Which split to load. 'train' reads `data_batch_1` through
        `data_batch_5`; 'validation' reads `test_batch`. All files are looked
        up directly under `self.data_dir`.

    Returns:
      A tuple `(all_images, all_labels)`: the 'data' entries of every batch
      concatenated and cast to float32, and the 'labels' entries of every
      batch concatenated.

    Raises:
      AssertionError: If `self.data_dir` is unset (synthetic data mode).
      ValueError: If `subset` is neither 'train' nor 'validation'.
    """
    import sys  # Local import: only needed to pick the unpickling arguments.

    assert self.data_dir, ('Cannot call `read_data_files` when using synthetic '
                           'data')
    if subset == 'train':
        filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                     for i in range(1, 6)]
    elif subset == 'validation':
        filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
        raise ValueError('Invalid data subset "%s"' % subset)

    # The batches were pickled by Python 2. Python 3 must be told to keep the
    # 8-bit strings as bytes; Python 2's loader has no `encoding` parameter.
    load_kwargs = {'encoding': 'bytes'} if sys.version_info[0] >= 3 else {}

    inputs = []
    for filename in filenames:
        # Pickle data is binary, so the file must be opened in 'rb' mode.
        # NOTE: unpickling can execute arbitrary code; `data_dir` must only
        # contain trusted dataset files.
        with gfile.Open(filename, 'rb') as f:
            inputs.append(cPickle.load(f, **load_kwargs))

    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format. Bytes keys match both the Python 3 `encoding='bytes'`
    # result and Python 2, where b'data' == 'data'.
    all_images = np.concatenate(
        [each_input[b'data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input[b'labels'] for each_input in inputs])
    return all_images, all_labels
# (web-page residue, not code) Comment list
# (web-page residue, not code) Table of contents