Config.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:NetworkCompress 作者: luzai 项目源码 文件源码
def load_data(self, limit_data, type='cifar10'):
        """Load, preprocess and cache a dataset, then expose it as ``self.dataset``.

        On the first call (while ``MyConfig.cache_data`` is None) the requested
        dataset is loaded, its inputs and labels are preprocessed, every split
        is passed through ``MyConfig._limit_data`` and the resulting dict
        (keys 'train_x', 'train_y', 'test_x', 'test_y') is stored on the class
        attribute ``MyConfig.cache_data``.  Subsequent calls reuse that dict.

        NOTE(review): the cache is not keyed by ``type`` or ``limit_data``; once
        populated, a later call with different arguments still receives the
        first dataset.  Set ``MyConfig.cache_data = None`` to force a reload.

        Args:
            limit_data: Forwarded to ``MyConfig._limit_data`` for each split
                (presumably a cap on the number of samples -- confirm against
                that helper).
            type: Dataset name: 'cifar10', 'mnist', 'cifar100' (fine labels)
                or 'svhn'.

        Raises:
            ValueError: If ``type`` is not one of the supported names and the
                cache is empty.
        """
        if MyConfig.cache_data is None:
            if type == 'cifar10':
                (train_x, train_y), (test_x, test_y) = cifar10.load_data()
            elif type == 'mnist':
                (train_x, train_y), (test_x, test_y) = mnist.load_data()
            elif type == 'cifar100':
                (train_x, train_y), (test_x, test_y) = cifar100.load_data(label_mode='fine')
            elif type == 'svhn':
                (train_x, train_y), (test_x, test_y) = load_data_svhn()
            else:
                # Without this branch an unknown name fell through and crashed
                # below with an unbound-local error on train_x.
                raise ValueError(
                    "unsupported dataset type %r; expected one of "
                    "'cifar10', 'mnist', 'cifar100', 'svhn'" % (type,))

            # The statistic returned for the training split (named mean_img,
            # presumably the training mean image) is reused for the test split
            # so both splits are preprocessed consistently.
            train_x, mean_img = self._preprocess_input(train_x, None)
            test_x, _ = self._preprocess_input(test_x, mean_img)

            train_y = self._preprocess_output(train_y)
            test_y = self._preprocess_output(test_y)

            res = {'train_x': train_x, 'train_y': train_y,
                   'test_x': test_x, 'test_y': test_y}

            # dict.items() works on Python 2 and 3 (iteritems() is Python-2
            # only); build a new dict instead of mutating while iterating.
            MyConfig.cache_data = {key: MyConfig._limit_data(val, limit_data)
                                   for key, val in res.items()}

        self.dataset = MyConfig.cache_data
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号