def MNIST_loader(root, image_size, normalize=True, *, train=True, download=True):
    """Build a torchvision MNIST dataset whose images are resized to ``image_size``.

    Args:
        root: Directory holding the MNIST data. If the data is not already
            there and ``download`` is true, it is downloaded into this directory.
        image_size: Target size passed to ``transforms.Resize``. An int resizes
            the shorter edge to this value (keeping aspect ratio); an (h, w)
            sequence resizes to exactly that size.
        normalize: If truthy, normalize each image with mean 0.5 and std 0.5,
            mapping the [0, 1] tensor produced by ``ToTensor`` to [-1, 1].
            Defaults to True.
        train: Keyword-only. Load the training split if true, otherwise the
            test split. Defaults to True (the previous, hard-coded behavior).
        download: Keyword-only. Download the data if it is not present under
            ``root``. Defaults to True (the previous, hard-coded behavior).

    Returns:
        A ``dset.MNIST`` dataset that applies the composed transforms
        (resize -> to-tensor -> optional normalize) to every image.
    """
    # ``transforms.Scale`` was deprecated in favour of ``transforms.Resize``
    # and has been removed from recent torchvision releases; Resize takes the
    # same ``size`` argument with the same semantics.
    transformations = [transforms.Resize(image_size), transforms.ToTensor()]
    if normalize:
        # One-channel mean/std: (x - 0.5) / 0.5 rescales [0, 1] to [-1, 1].
        transformations.append(transforms.Normalize((0.5,), (0.5,)))
    return dset.MNIST(
        root=root,
        train=train,
        download=download,
        transform=transforms.Compose(transformations),
    )
评论列表
文章目录