utils.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:sourceseparation_misc 作者: ycemsubakan 项目源码 文件源码
def get_loaders(loader_batchsize, **kwargs):
    arguments=kwargs['arguments']
    data = arguments.data

    if data == 'mnist':
        kwargs = {'num_workers': 1, 'pin_memory': True} if arguments.cuda else {}
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               #transforms.Normalize((0,), (1,))
                           ])),
            batch_size=loader_batchsize, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               #transforms.Normalize((7,), (0.3081,))
                           ])),
            batch_size=loader_batchsize, shuffle=True, **kwargs)

    return train_loader, test_loader
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号