def backward(ctx, grad_output):
    def safe_zeros_backward(inp, dim):
        # note that the gradient is equivalent to:
        # cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
        # input:                         [    a,     b,     c]
        # cumprod(exclusive, normal):    [    1,     a, a * b]
        # cumprod(exclusive, reverse):   [b * c,     c,     1]
        # product:                       [b * c, a * c, a * b]
        # and this is safe under input with 0s.
        if inp.size(dim) == 1:
            return grad_output

        ones_size = torch.Size((inp.size()[:dim] + (1,) + inp.size()[dim + 1:]))
        ones = Variable(grad_output.data.new(ones_size).fill_(1))
        exclusive_normal_nocp = torch.cat((ones, inp.narrow(dim, 0, inp.size(dim) - 1)), dim)
        exclusive_normal = exclusive_normal_nocp.cumprod(dim)

        def reverse_dim(var, dim):
            index = Variable(torch.arange(var.size(dim) - 1, -1, -1, out=var.data.new().long()))
            return var.index_select(dim, index)

        narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim)
        exclusive_reverse_nocp = torch.cat((ones, narrow_reverse), dim)
        exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim)

        grad_input = grad_output.expand_as(exclusive_normal).mul(exclusive_normal.mul(exclusive_reverse))
        return grad_input
    if ctx.dim is None:
        # prod over all elements: the gradient depends on how many zeros the input contains
        input, = ctx.saved_variables
        zero_idx = (input.data == 0).nonzero()
        if zero_idx.dim() == 0:
            # no zeros: d prod / d input_i = prod / input_i
            return grad_output.mul(ctx.result).expand_as(input).div(input), None, None
        elif zero_idx.size(0) > 1:
            # two or more zeros: every partial derivative is zero
            return (grad_output * 0).expand_as(input), None, None
        else:
            # exactly one zero: fall back to the division-free exclusive-cumprod formula
            return safe_zeros_backward(input.contiguous().view(-1), 0).view_as(input), None, None
    else:
        # prod over a single dimension
        input, output = ctx.saved_variables
        dim = ctx.dim if ctx.dim >= 0 else ctx.dim + input.dim()
        if ctx.keepdim is False and len(ctx.input_size) != 1:
            # re-insert the reduced dimension so shapes broadcast against input
            grad_output = grad_output.unsqueeze(dim)
            output = output.unsqueeze(dim)

        zero_mask = input == 0
        slice_zero_count = zero_mask.sum(dim, True)
        total_zeros = slice_zero_count.data.sum()
        if total_zeros == 0:
            grad_input = grad_output.mul(output).expand_as(input).div(input)
        else:
            grad_input = safe_zeros_backward(input, dim)

        return grad_input, None, None
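
The exclusive-cumprod trick in safe_zeros_backward can be checked numerically. The short sketch below is not part of the original source and assumes a recent PyTorch where plain tensors carry autograd state (no Variable wrapper); it compares the product of the two exclusive cumprods against the gradient autograd reports for prod() on a vector with exactly one zero, where the naive output / input formula would produce a NaN at the zero position.

# Minimal sketch (assumption: recent PyTorch, tensors with requires_grad, no Variable).
import torch

x = torch.tensor([2.0, 0.0, 3.0, 4.0], requires_grad=True)

# Gradient of prod(x) as computed by autograd.
y = x.prod()
y.backward()
auto_grad = x.grad

# Same gradient via exclusive cumprods in both directions:
# d prod(x) / d x_i = product of all other entries, which stays finite at zeros.
with torch.no_grad():
    ones = torch.ones(1)
    exclusive_normal = torch.cat((ones, x[:-1])).cumprod(0)
    exclusive_reverse = torch.cat((ones, x[1:].flip(0))).cumprod(0).flip(0)
    manual_grad = exclusive_normal * exclusive_reverse

print(auto_grad)                                 # expected: [0., 24., 0., 0.]
print(torch.allclose(auto_grad, manual_grad))    # expected: True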