def collate_fn(data):
    """Build mini-batch tensors from a list of dataset samples.

    Samples are ordered by caption length (longest first) and the captions
    are right-padded with zeros to the length of the longest caption.

    Args:
        data: list of (image, caption, id, img_id) tuples.
            - image: torch tensor; every image in the batch must have the
              same shape so they can be stacked (e.g. (3, 256, 256)).
            - caption: 1D torch tensor of variable length.
            - id: per-sample identifier, passed through to the output.
            - img_id: unpacked but not returned.

    Returns:
        images: torch tensor of shape (batch_size, *image_shape).
        targets: long tensor of shape (batch_size, padded_length), where
            padded_length is the longest caption length; unused trailing
            positions are zero.
        lengths: list; valid length for each padded caption, in
            descending order.
        ids: tuple of the sample ids, in the same (sorted) order as the
            rows of ``images`` and ``targets``.

    Raises:
        ValueError: if ``data`` is empty or a sample is not a 4-tuple
            (raised by the tuple unpacking).
    """
    # Sort by caption length, longest first. sorted() is used instead of
    # list.sort() so the caller's list is not reordered as a side effect.
    ordered = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    images, captions, ids, _ = zip(*ordered)

    # Merge images (convert tuple of 3D tensors to one 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (convert tuple of 1D tensors to one zero-padded 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths, ids
# Comment list
# Table of contents