train.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:optnet 作者: locuslab 项目源码 文件源码
def get_loaders(args):
    """Build the train and test data loaders for the dataset named in *args*.

    Args:
        args: Parsed options namespace. This function reads ``args.cuda``
            (when truthy, loaders use one worker process and pinned memory),
            ``args.dataset`` (``'mnist'`` or ``'cifar-10'``) and
            ``args.batchSz`` (batch size used for both loaders).

    Returns:
        A ``(trainLoader, testLoader)`` tuple of
        ``torch.utils.data.DataLoader`` objects. The train loader shuffles;
        the test loader does not.

    Raises:
        ValueError: If ``args.dataset`` is not ``'mnist'`` or ``'cifar-10'``.
    """
    # Extra loader options are only enabled for CUDA runs (pinned memory is
    # presumably meant to speed host-to-GPU copies).
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    if args.dataset == 'mnist':
        # Same fixed normalization for train and test; no augmentation.
        mnistTransform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        trainSet = dset.MNIST('data/mnist', train=True, download=True,
                              transform=mnistTransform)
        testSet = dset.MNIST('data/mnist', train=False, download=True,
                             transform=mnistTransform)
    elif args.dataset == 'cifar-10':
        # Per-channel (RGB) normalization statistics.
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]
        normTransform = transforms.Normalize(normMean, normStd)

        # Random crop (with 4px padding) and horizontal flip augment the
        # training set only; the test set is just tensorized and normalized.
        trainTransform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normTransform,
        ])
        testTransform = transforms.Compose([
            transforms.ToTensor(),
            normTransform,
        ])

        trainSet = dset.CIFAR10(root='data/cifar', train=True, download=True,
                                transform=trainTransform)
        testSet = dset.CIFAR10(root='data/cifar', train=False, download=True,
                               transform=testTransform)
    else:
        # An explicit raise survives `python -O` (an `assert` would be
        # stripped and the return below would hit unbound names).
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))

    trainLoader = torch.utils.data.DataLoader(
        trainSet, batch_size=args.batchSz, shuffle=True, **kwargs)
    testLoader = torch.utils.data.DataLoader(
        testSet, batch_size=args.batchSz, shuffle=False, **kwargs)
    return trainLoader, testLoader
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号