def mnist(self, root="/home/dataset/mnist"):
    """Build the MNIST train and test data loaders.

    Both splits share one preprocessing pipeline: ``ToTensor`` followed by
    ``Normalize`` with mean 0.1307 and std 0.3081 (presumably the MNIST
    training-set pixel statistics -- confirm if changing the dataset).

    Args:
        root: Directory that holds (or will receive) the MNIST files.
            Defaults to the previously hard-coded location, so existing
            callers are unaffected.

    Returns:
        A ``(train_loader, test_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects. Batch sizes are read from
        ``self.train_batch_size`` and ``self.test_batch_size``; the worker
        count is read from ``self.n_threads``.
    """
    norm_mean = [0.1307]
    norm_std = [0.3081]
    # One shared transform guarantees both splits are preprocessed identically.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    train_loader = torch.utils.data.DataLoader(
        dsets.MNIST(root, train=True, download=True, transform=transform),
        batch_size=self.train_batch_size,
        shuffle=True,
        num_workers=self.n_threads,
        pin_memory=False,
    )

    # download=True keeps the test dataset self-sufficient; per the
    # torchvision docs it does not re-download files that already exist.
    # shuffle=False: evaluation does not need a random order, and a fixed
    # order makes per-batch test results reproducible.
    test_loader = torch.utils.data.DataLoader(
        dsets.MNIST(root, train=False, download=True, transform=transform),
        batch_size=self.test_batch_size,
        shuffle=False,
        num_workers=self.n_threads,
        pin_memory=False,
    )
    return train_loader, test_loader
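# 评论列表 (comment list -- leftover page-navigation text from the source web page)
# 文章目录 (table of contents -- leftover page-navigation text from the source web page)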