def sort_batch(data, seq_len):
sorted_seq_len, sorted_idx = torch.sort(seq_len, dim=0, descending=True)
sorted_data = data[sorted_idx.data]
_, reverse_idx = torch.sort(sorted_idx, dim=0, descending=False)
return sorted_data, sorted_seq_len.cuda(), reverse_idx.cuda()
评论列表
文章目录