dataloader.py source code

python

Project: ExperimentPackage_PyTorch  Author: ICEORY
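import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms


def mnist(self):
    # self.train_batch_size, self.test_batch_size and self.n_threads are
    # attributes of the enclosing class (not shown in this snippet)
    # Standard MNIST mean and std for the single grayscale channel
    norm_mean = [0.1307]
    norm_std = [0.3081]
    # Same preprocessing for both splits: PIL image -> [0, 1] tensor -> normalized
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    # Training split; downloaded to /home/dataset/mnist on first use
    train_loader = torch.utils.data.DataLoader(
        dsets.MNIST("/home/dataset/mnist", train=True, download=True,
                    transform=transform),
        batch_size=self.train_batch_size, shuffle=True,
        num_workers=self.n_threads,
        pin_memory=False
    )
    # Test split; evaluation order does not affect the metrics, so no shuffling
    test_loader = torch.utils.data.DataLoader(
        dsets.MNIST("/home/dataset/mnist", train=False, transform=transform),
        batch_size=self.test_batch_size, shuffle=False,
        num_workers=self.n_threads,
        pin_memory=False
    )
    return train_loader, test_loader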
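To make the preprocessing concrete, here is a minimal pure-Python sketch (no PyTorch required) that reproduces, for a single pixel, what transforms.ToTensor() followed by transforms.Normalize(norm_mean, norm_std) computes, and how many batches each DataLoader yields. The helper names (to_tensor_value, normalize_value, num_batches) and the batch sizes 128 and 1000 are illustrative assumptions, not part of the original project.

```python
import math

# MNIST statistics used in mnist() above
NORM_MEAN = 0.1307
NORM_STD = 0.3081


def to_tensor_value(pixel: int) -> float:
    """Hypothetical helper: mimic ToTensor() on one uint8 grayscale pixel (0..255 -> 0.0..1.0)."""
    return pixel / 255.0


def normalize_value(x: float, mean: float = NORM_MEAN, std: float = NORM_STD) -> float:
    """Hypothetical helper: mimic Normalize(mean, std) on one channel value, i.e. (x - mean) / std."""
    return (x - mean) / std


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: batches a DataLoader yields over a map-style dataset."""
    if drop_last:
        # The final incomplete batch is discarded
        return num_samples // batch_size
    # Default: the final batch may be smaller than batch_size
    return math.ceil(num_samples / batch_size)


if __name__ == "__main__":
    # Black, mid-gray and white pixels after ToTensor + Normalize
    for p in (0, 128, 255):
        print(p, round(normalize_value(to_tensor_value(p)), 4))
    # MNIST has 60,000 training and 10,000 test images;
    # the batch sizes below are example values only
    print(num_batches(60000, 128))   # → 469
    print(num_batches(10000, 1000))  # → 10
```

With these statistics a black pixel maps to about -0.42 and a white pixel to about 2.82, so the normalized training inputs have roughly zero mean and unit standard deviation.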