def __init__(self, tensor, lengths):
    """Reorder a batch by descending sequence length and build a validity mask.

    Args:
        tensor: batch whose dim 0 indexes sequences; it is reordered along
            dim 0 so the longest sequence comes first. Presumably shaped
            (batch, seq_len, ...) -- TODO confirm against callers.
        lengths: per-sequence lengths in the ORIGINAL batch order (anything
            ``torch.LongTensor`` accepts, e.g. a list of ints).

    Attributes set:
        original_lengths: ``lengths`` exactly as passed in.
        sorted_idx: CPU LongTensor; ``sorted_idx[k]`` is the original
            position of the k-th longest sequence.
        tensor: ``tensor`` reordered by ``sorted_idx`` along dim 0 (stays on
            the input tensor's device).
        lengths: list of the sorted lengths (elements are 0-d tensors, as
            produced by iterating a tensor).
        original_idx: LongTensor built from ``sort_idx(sorted_idx)`` --
            presumably the inverse permutation that restores the original
            order; verify against ``sort_idx``.
        mask_original: zero-initialised float tensor with the same size as
            ``self.tensor``. Row ``i`` (original order) has ones in its first
            ``original_lengths[i]`` positions along dim 1.
    """
    self.original_lengths = lengths
    sorted_lengths_tensor, self.sorted_idx = torch.sort(torch.LongTensor(lengths), dim=0, descending=True)
    # The sort indices live on CPU; index_select needs the index on the same
    # device as the tensor being indexed, so move it for CUDA inputs while
    # keeping self.sorted_idx itself on CPU as before.
    self.tensor = tensor.index_select(dim=0, index=self.sorted_idx.to(tensor.device))
    self.lengths = list(sorted_lengths_tensor)
    self.original_idx = torch.LongTensor(sort_idx(self.sorted_idx))
    # NOTE(review): the mask is created on CPU with the default dtype
    # regardless of tensor's device/dtype; callers may need to move it.
    self.mask_original = torch.zeros(*self.tensor.size())
    # Iterate the ORIGINAL lengths so mask rows line up with the unsorted batch.
    for i, length in enumerate(self.original_lengths):
        self.mask_original[i][:length].fill_(1)
评论列表
文章目录