def get_loaders(loader_batchsize, **kwargs):
    """Build the train and test DataLoaders for the dataset named in ``arguments``.

    Only MNIST is currently supported. Both splits are read from ``'../data'``;
    the training split is downloaded there if missing.

    Args:
        loader_batchsize: Batch size used for both the train and test loaders.
        **kwargs: Must contain ``arguments``, an options object exposing
            ``data`` (dataset name, e.g. ``'mnist'``) and ``cuda`` (truthy
            when loaders should use worker processes and pinned memory).

    Returns:
        A ``(train_loader, test_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects.

    Raises:
        KeyError: If ``arguments`` is not passed.
        ValueError: If ``arguments.data`` names an unsupported dataset.
    """
    arguments = kwargs['arguments']
    data = arguments.data
    if data != 'mnist':
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as a confusing unpacking error in the caller.
        raise ValueError(
            "unsupported dataset %r (only 'mnist' is implemented)" % (data,))

    # Extra DataLoader options used only when running on CUDA. Stored under
    # its own name so the function's **kwargs is not shadowed.
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if arguments.cuda else {}

    # Both splits share the same preprocessing: ToTensor only, no Normalize.
    transform = transforms.Compose([transforms.ToTensor()])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=loader_batchsize, shuffle=True, **loader_kwargs)
    # NOTE(review): the test split is shuffled too (behavior kept from the
    # original); harmless for aggregate metrics, but confirm it is intended.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=loader_batchsize, shuffle=True, **loader_kwargs)
    return train_loader, test_loader
评论列表
文章目录