def __init__(self, dataset, train_batch_size, test_batch_size, n_threads=4, ten_crop=False):
    """Store the loader settings and build the train/test data loaders.

    Args:
        dataset: dataset name; only "cifar10", "cifar100" and "mnist" are
            accepted.
        train_batch_size: stored as ``self.train_batch_size`` (presumably
            read by ``self.cifar`` / ``self.mnist`` -- confirm).
        test_batch_size: stored as ``self.test_batch_size`` (presumably
            read by ``self.cifar`` / ``self.mnist`` -- confirm).
        n_threads: stored as ``self.n_threads``; defaults to 4.
        ten_crop: stored as ``self.ten_crop``; defaults to False.

    Sets ``self.train_loader`` and ``self.test_loader`` from the two-item
    result of ``self.cifar(dataset=...)`` or ``self.mnist()``.

    Raises:
        AssertionError: if ``dataset`` is not one of the accepted names.
    """
    self.dataset = dataset
    self.train_batch_size = train_batch_size
    self.test_batch_size = test_batch_size
    self.n_threads = n_threads
    self.ten_crop = ten_crop

    if self.dataset in ("cifar10", "cifar100"):
        # Single-argument print(...) prints the same text on Python 2 and 3.
        print("|===>Creating Cifar Data Loader")
        self.train_loader, self.test_loader = self.cifar(dataset=self.dataset)
    elif self.dataset == "mnist":
        print("|===>Creating MNIST Data Loader")
        self.train_loader, self.test_loader = self.mnist()
    else:
        # Raise explicitly rather than `assert False`: assert statements are
        # removed under `python -O`, which would leave the instance without
        # train_loader/test_loader. AssertionError is kept so callers that
        # already catch it keep working.
        raise AssertionError("invalid data set")
评论列表
文章目录