def load_data(self, limit_data, type='cifar10'):
    """Load, preprocess and cache a dataset, then expose it as ``self.dataset``.

    The loading work runs only while ``MyConfig.cache_data`` is ``None``; the
    result is stored there and reused by every later call.  The cache is not
    keyed, so a later call with a different ``type`` or ``limit_data`` still
    receives whatever the first call stored.

    Args:
        limit_data: Forwarded unchanged to ``MyConfig._limit_data`` for each
            of the four arrays.
        type: Dataset selector: ``'cifar10'``, ``'mnist'``, ``'cifar100'``
            (loaded with ``label_mode='fine'``) or ``'svhn'``.  The name
            shadows the builtin but is kept so keyword callers keep working.

    Raises:
        ValueError: If the cache is empty and ``type`` is not one of the
            supported dataset names.
    """
    if MyConfig.cache_data is None:
        if type == 'cifar10':
            (train_x, train_y), (test_x, test_y) = cifar10.load_data()
        elif type == 'mnist':
            (train_x, train_y), (test_x, test_y) = mnist.load_data()
        elif type == 'cifar100':
            (train_x, train_y), (test_x, test_y) = cifar100.load_data(label_mode='fine')
        elif type == 'svhn':
            (train_x, train_y), (test_x, test_y) = load_data_svhn()
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the first use of train_x below.
            raise ValueError('Unsupported dataset type: %r' % (type,))

        # The second value returned for the training inputs (presumably a
        # mean image, judging by the name) is fed into the test-set call.
        train_x, mean_img = self._preprocess_input(train_x, None)
        test_x, _ = self._preprocess_input(test_x, mean_img)
        train_y = self._preprocess_output(train_y)
        test_y = self._preprocess_output(test_y)

        res = {'train_x': train_x, 'train_y': train_y, 'test_x': test_x, 'test_y': test_y}
        # dict.iteritems() exists only on Python 2; items() works on both.
        MyConfig.cache_data = {
            key: MyConfig._limit_data(val, limit_data)
            for key, val in res.items()
        }
    self.dataset = MyConfig.cache_data
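# Comment list (web-page navigation residue; not part of the code)
# Table of contents (web-page navigation residue; not part of the code)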