tensor.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def backward(ctx, grad_output):
    """Propagate ``grad_output`` back through a product reduction.

    Reads ``ctx.dim`` and ``ctx.saved_variables``; additionally reads
    ``ctx.result`` when ``ctx.dim`` is ``None``, and ``ctx.keepdim`` /
    ``ctx.input_size`` otherwise.
    NOTE(review): ``ctx.result`` and the second saved variable are presumably
    the forward product -- confirm against the matching ``forward``.

    Returns:
        A 3-tuple ``(grad_input, None, None)``.
    """
    def flip_along(var, axis):
        # Reverse ``var`` along ``axis`` by selecting with a descending index.
        last = var.size(axis) - 1
        descending = torch.arange(last, -1, -1, out=var.data.new().long())
        return var.index_select(axis, Variable(descending))

    def zero_safe_grad(grad, inp, axis):
        # Division-free gradient, so zeros in ``inp`` are harmless. For each
        # position along ``axis`` it multiplies the product of everything
        # before it by the product of everything after it:
        #   inp:                  [    a,     b,     c]
        #   exclusive prefix:     [    1,     a, a * b]
        #   exclusive suffix:     [b * c,     c,     1]
        #   prefix * suffix:      [b * c, a * c, a * b]
        length = inp.size(axis)
        if length == 1:
            return grad

        shape = inp.size()
        unit_shape = torch.Size((shape[:axis] + (1,) + shape[axis + 1:]))
        unit = Variable(grad.data.new(unit_shape).fill_(1))

        # Shift right by one slot (prepend ones, drop the last entry), then
        # take the running product.
        shifted = torch.cat((unit, inp.narrow(axis, 0, length - 1)), axis)
        prefix = shifted.cumprod(axis)

        # Same construction on the reversed tail, flipped back afterwards.
        tail = flip_along(inp.narrow(axis, 1, length - 1), axis)
        tail_running = torch.cat((unit, tail), axis).cumprod(axis)
        suffix = flip_along(tail_running, axis)

        return grad.expand_as(prefix).mul(prefix.mul(suffix))

    if ctx.dim is None:
        # Reduction over every element of the saved input.
        saved, = ctx.saved_variables
        zero_positions = (saved.data == 0).nonzero()
        if zero_positions.dim() == 0:
            # No zero found: grad * result / input is safe to evaluate.
            grad_input = grad_output.mul(ctx.result).expand_as(saved).div(saved)
        elif zero_positions.size(0) > 1:
            # More than one zero: every partial product contains a zero.
            grad_input = (grad_output * 0).expand_as(saved)
        else:
            flat = saved.contiguous().view(-1)
            grad_input = zero_safe_grad(grad_output, flat, 0).view_as(saved)
        return grad_input, None, None

    # Reduction along a single dimension.
    saved, product = ctx.saved_variables
    axis = ctx.dim + saved.dim() if ctx.dim < 0 else ctx.dim
    if ctx.keepdim is False and len(ctx.input_size) != 1:
        # Re-insert the reduced dimension so both tensors expand to the input.
        grad_output = grad_output.unsqueeze(axis)
        product = product.unsqueeze(axis)

    zeros_per_slice = (saved == 0).sum(axis, True)
    if zeros_per_slice.data.sum() == 0:
        grad_input = grad_output.mul(product).expand_as(saved).div(saved)
    else:
        grad_input = zero_safe_grad(grad_output, saved, axis)

    return grad_input, None, None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号