def _make_dataloaders(train_set, valid_set, test_set, train_size, valid_size, batch_size):
    """Build the train, validation and test ``DataLoader`` objects.

    A random permutation of ``range(len(train_set))`` is drawn. Its last
    ``valid_size`` entries become the validation indices; of the remaining
    entries, the first ``train_size`` become the training indices. Both index
    sets come from the same permutation, so they never overlap.

    Args:
        train_set: Dataset read by the training loader; its length defines
            the index range that is permuted and split.
        valid_set: Dataset read by the validation loader. It is indexed with
            indices drawn from ``range(len(train_set))``.
            NOTE(review): this assumes ``valid_set`` is index-compatible with
            ``train_set`` (at least as long) -- confirm against the caller.
        test_set: Dataset read by the test loader, without any sampler.
        train_size: Maximum number of training samples to keep. ``0`` or
            ``None`` keeps every sample not held out for validation.
        valid_size: Number of samples held out for validation. ``0`` or
            ``None`` disables the validation loader.
        batch_size: Batch size passed to every loader.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple; ``valid_loader``
        is ``None`` when ``valid_size`` is ``0`` or ``None``.

    Raises:
        ValueError: If ``valid_size`` is negative or exceeds
            ``len(train_set)``.
    """
    # Normalise None to 0 so the split arithmetic below never sees None.
    valid_size = valid_size or 0
    total = len(train_set)
    if not 0 <= valid_size <= total:
        # Out-of-range values would otherwise be reinterpreted by slicing
        # (negative offsets) and yield a silently wrong split.
        raise ValueError(
            f"valid_size must be between 0 and len(train_set)={total}, got {valid_size}"
        )

    # Split training into train and validation
    indices = torch.randperm(total)
    split = total - valid_size
    # `train_size or None` turns 0/None into an open-ended slice (keep all).
    train_indices = indices[:split][:train_size or None]

    train_loader = torch.utils.data.DataLoader(
        train_set,
        pin_memory=True,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(train_indices),
    )
    test_loader = torch.utils.data.DataLoader(test_set, pin_memory=True, batch_size=batch_size)

    valid_loader = None
    if valid_size:
        valid_indices = indices[split:]
        valid_loader = torch.utils.data.DataLoader(
            valid_set,
            pin_memory=True,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(valid_indices),
        )
    return train_loader, valid_loader, test_loader
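# Web-page navigation residue (not code): "Comment list"
# Web-page navigation residue (not code): "Article table of contents"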