def train_collate_fn(data):
    """Merge a list of training samples into a single minibatch.

    Meant to be passed as ``collate_fn`` to a ``torch.utils.data.DataLoader``.

    Args:
        data: sequence of ``(video, caption, length, video_id)`` tuples (the
            element order is fixed by the unpacking below). Every ``video``
            must have the same shape, and so must every ``caption``, because
            they are combined with ``torch.stack``. The captions are therefore
            presumably pre-padded to a fixed length; TODO confirm against the
            Dataset's ``__getitem__``.

    Returns:
        ``(videos, captions, lengths, video_ids)``, all in the same sample
        order, sorted by ``length`` from longest to shortest:
            videos: the video tensors stacked along a new leading batch dim.
            captions: the caption tensors stacked along a new leading batch dim.
            lengths: tuple of the per-sample lengths (descending).
            video_ids: tuple of the per-sample video ids.
    """
    # Sort longest-first by the ``lengths`` slot (index 2). Keying on x[-1]
    # would sort by the ``video_id`` slot instead. NOTE(review): descending
    # lengths is presumably what downstream code (e.g. pack_padded_sequence
    # with enforce_sorted=True) relies on; confirm.
    # sorted() is stable like list.sort, but leaves the caller's list intact.
    data = sorted(data, key=lambda sample: sample[2], reverse=True)
    videos, captions, lengths, video_ids = zip(*data)
    # Stack N per-sample video tensors into one tensor with a new batch dim 0
    # (the original comment describes this as turning 2D tensors into 3D).
    videos = torch.stack(videos, 0)
    # Same for the captions (described originally as 1D tensors -> 2D).
    captions = torch.stack(captions, 0)
    return videos, captions, lengths, video_ids
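# 评论列表 ("Comment list") and 文章目录 ("Table of contents"):
# navigation text left over from the web page this code was copied from.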