def backward(self, grad_output):
    # Gradient of cumsum along self.dim is the reversed cumulative sum of
    # grad_output: grad_input[i] = sum_{j >= i} grad_output[j].
    # cumsum(-grad_output)[i] = -sum_{j <= i} grad_output[j]; subtracting the
    # last slice (the negated total) leaves sum_{j > i} grad_output[j],
    # and adding grad_output restores the j == i term.
    grad_input = torch.cumsum(-grad_output, dim=self.dim)
    end_idx = grad_input.size(self.dim) - 1
    grad_sum = grad_input.narrow(self.dim, end_idx, 1)
    grad_input -= grad_sum.expand_as(grad_input)
    grad_input += grad_output
    return grad_input

# TODO: unfold
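
For reference, the same reversed cumulative sum can be written directly with torch.flip. The short sketch below is illustrative and not part of the original source; it compares autograd's gradient for torch.cumsum against that flipped formulation (the tensor names are made up for the example).

import torch

# Check that the gradient of cumsum is the reversed cumulative sum of grad_output.
x = torch.randn(5, requires_grad=True)
grad_output = torch.randn(5)

y = torch.cumsum(x, dim=0)
y.backward(grad_output)

# Reversed cumulative sum, written with flip for comparison.
expected = torch.flip(torch.cumsum(torch.flip(grad_output, dims=[0]), dim=0), dims=[0])
print(torch.allclose(x.grad, expected))  # True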