def detection_collate(batch):
    """Collate samples whose images carry a varying number of annotations.

    The default collate function cannot stack per-image annotation lists of
    different lengths into one tensor. This function stacks the images and
    returns the annotations as a per-image list of tensors instead.

    Args:
        batch: Iterable of samples, each itself an iterable (e.g. a tuple).
            Every element of a sample that is a tensor is collected as an
            image. Every element that is a ``list`` is collected as one set
            of annotations (one entry per object, such as a list of numbers,
            a number, or a tensor). Elements of any other type are ignored.

    Returns:
        tuple: ``(images, targets)``.
            ``images`` is all collected image tensors stacked along a new
            dim 0.
            ``targets`` is a list with one tensor per annotation list, whose
            per-object entries are stacked along dim 0 using the default
            floating dtype.
            An empty annotation list (an image with no objects) yields an
            empty tensor of shape ``(0,)``, because the per-object width
            cannot be known here.

    Raises:
        RuntimeError: propagated from ``torch.stack`` if no image tensors
            were found or the images/per-object entries differ in shape.
    """
    # torch.Tensor(...) used the default tensor type; keep that dtype.
    dtype = torch.get_default_dtype()
    imgs = []
    targets = []
    for sample in batch:
        for item in sample:
            if torch.is_tensor(item):
                imgs.append(item)
            elif isinstance(item, list):
                if not item:
                    # torch.stack rejects an empty sequence; an image without
                    # objects must not crash the whole batch.
                    targets.append(torch.zeros(0, dtype=dtype))
                    continue
                # as_tensor wraps ints as values; the legacy torch.Tensor(int)
                # constructor would treat them as a size instead.
                annos = [torch.as_tensor(a, dtype=dtype) for a in item]
                targets.append(torch.stack(annos, 0))
    return torch.stack(imgs, 0), targets
评论列表
文章目录