dataset.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:R-net 作者: matthew-z 项目源码 文件源码
def __init__(self, tensor, lengths):
        self.original_lengths = lengths
        sorted_lengths_tensor, self.sorted_idx = torch.sort(torch.LongTensor(lengths), dim=0, descending=True)

        self.tensor = tensor.index_select(dim=0, index=self.sorted_idx)

        self.lengths = list(sorted_lengths_tensor)
        self.original_idx = torch.LongTensor(sort_idx(self.sorted_idx))

        self.mask_original = torch.zeros(*self.tensor.size())
        for i, length in enumerate(self.original_lengths):
            self.mask_original[i][:length].fill_(1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号