def get_loader_single(data_name, split, root, json, vocab, transform,
                      batch_size=100, shuffle=True,
                      num_workers=2, ids=None, collate_fn=collate_fn):
    """Return a ``torch.utils.data.DataLoader`` over a single dataset.

    The dataset class is selected by substring match on ``data_name``:
    a name containing ``'coco'`` builds a ``CocoDataset``; a name containing
    ``'f8k'`` or ``'f30k'`` builds a ``FlickrDataset``.

    Args:
        data_name: Dataset identifier, matched by substring as described above.
        split: Split name; forwarded only to ``FlickrDataset`` (not used for
            the COCO branch in this function).
        root: Passed through to the dataset constructor as ``root``.
        json: Passed through to the dataset constructor as ``json``. Note this
            parameter shadows the stdlib ``json`` module name inside here.
        vocab: Passed through to the dataset constructor as ``vocab``.
        transform: Passed through to the dataset constructor as ``transform``.
        batch_size: Number of samples per batch.
        shuffle: Whether the DataLoader reshuffles the data every epoch.
        num_workers: Number of DataLoader worker processes.
        ids: Forwarded only to ``CocoDataset`` (presumably a subset of
            annotation ids to use — confirm against ``CocoDataset``).
        collate_fn: Function the DataLoader uses to merge samples into a batch.

    Returns:
        A ``torch.utils.data.DataLoader`` created with ``pin_memory=True``.

    Raises:
        ValueError: If ``data_name`` matches none of the supported datasets.
    """
    if 'coco' in data_name:
        # COCO custom dataset
        dataset = CocoDataset(root=root,
                              json=json,
                              vocab=vocab,
                              transform=transform, ids=ids)
    elif 'f8k' in data_name or 'f30k' in data_name:
        dataset = FlickrDataset(root=root,
                                split=split,
                                json=json,
                                vocab=vocab,
                                transform=transform)
    else:
        # Fail fast with a clear message; otherwise ``dataset`` would be
        # unbound below and surface as an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported data_name {!r}: expected a name containing "
            "'coco', 'f8k' or 'f30k'".format(data_name))

    # Data loader
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              pin_memory=True,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader
评论列表
文章目录