AoAReader.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:AoAReader 作者: kevinkwl 项目源码 文件源码
def sort_batch(data, seq_len):
    sorted_seq_len, sorted_idx = torch.sort(seq_len, dim=0, descending=True)
    sorted_data = data[sorted_idx.data]
    _, reverse_idx = torch.sort(sorted_idx, dim=0, descending=False)
    return sorted_data, sorted_seq_len.cuda(), reverse_idx.cuda()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号