demo.py source code

python

Project: efficient_densenet_pytorch, author: gpleiss
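import torch
from torch.utils.data.sampler import SubsetRandomSampler


def _make_dataloaders(train_set, valid_set, test_set, train_size, valid_size, batch_size):
    # Treat valid_size=None like 0 (no validation split) so the slicing below does not fail
    valid_size = valid_size or 0

    # Split training into train and validation: shuffle all training indices once,
    # hold out the last `valid_size` for validation, and optionally cap the rest at `train_size`
    indices = torch.randperm(len(train_set))
    train_indices = indices[:len(indices) - valid_size][:train_size or None]
    valid_indices = indices[len(indices) - valid_size:] if valid_size else None

    # Training batches are drawn in random order from the training subset only
    train_loader = torch.utils.data.DataLoader(train_set, pin_memory=True, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(train_indices))
    # The test set is used whole, in its original order
    test_loader = torch.utils.data.DataLoader(test_set, pin_memory=True, batch_size=batch_size)
    if valid_size:
        # valid_set is indexed with the same permutation as train_set, so it must wrap
        # the same underlying training samples (e.g. with different transforms)
        valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size,
                                                   sampler=SubsetRandomSampler(valid_indices))
    else:
        valid_loader = None

    return train_loader, valid_loader, test_loader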
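The index arithmetic is the subtle part of this function: the last `valid_size` shuffled indices become the validation set, and `train_size or None` caps the remaining training indices (a `train_size` of `None` or `0` slices with `None`, which keeps everything). The sketch below reproduces only that slicing on plain Python lists so it can be checked without PyTorch. The helper name `split_indices` and its `seed` parameter are hypothetical, and `random.shuffle` stands in for `torch.randperm`.

```python
import random


def split_indices(n, train_size=None, valid_size=0, seed=None):
    """Hypothetical helper mirroring the slicing in _make_dataloaders on plain lists."""
    valid_size = valid_size or 0
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)  # stands in for torch.randperm(n)
    # Everything except the last `valid_size` indices is eligible for training,
    # optionally truncated to the first `train_size` of them.
    train_indices = indices[:n - valid_size][:train_size or None]
    # The last `valid_size` shuffled indices are held out for validation.
    valid_indices = indices[n - valid_size:] if valid_size else None
    return train_indices, valid_indices


train_idx, valid_idx = split_indices(10, train_size=4, valid_size=3, seed=0)
print(len(train_idx), len(valid_idx))        # 4 3
print(set(train_idx).isdisjoint(valid_idx))  # True
```

Because the validation indices are taken from the tail of the permutation before `train_size` is applied, shrinking the training set never leaks validation samples into it.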