segment.py 文件源码

python
阅读 46 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def backward(ctx, grad_outputs):
        size = grad_outputs.size(1)
        segm_sorted = torch.sort(ctx.rev_segm_sorted)[1]
        grad_outputs = torch.index_select(grad_outputs, 0, segm_sorted)

        offset = [ctx.num_zeros]

        def backward_segment(l, n):
            segment_grad = grad_outputs.narrow(0, offset[0], n // l)
            if l > 1:
                segment_grad = _MyMax.backward(ctx.maxes[l], segment_grad)[0].view(n, size)
            offset[0] += n // l
            return segment_grad

        segment_grads = [backward_segment(l, n) for l, n in enumerate(ctx.num_lengths) if n > 0]
        grads = torch.cat(segment_grads, 0)
        rev_length_sorted = torch.sort(ctx.lengths_sorted)[1]
        grads = torch.index_select(grads, 0, rev_length_sorted)

        return grads, None, None, None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号