def sum_scan_exclusive(x, dim):
    """Return the reverse (suffix) cumulative sum of ``x`` along ``dim``.

    Element ``i`` of the result along ``dim`` is ``sum(x[j] for j >= i)``.
    Note that, despite the function's name, the sum *includes* ``x[i]``
    itself; this matches the values the original implementation produced.

    The sum is accumulated directly from the end of ``dim`` (flip, forward
    cumsum, flip back) rather than as ``total - prefix + x``.  The latter
    form suffers from catastrophic cancellation in floating point, e.g. for
    float32 ``[1e8, 1, 1]`` it returned ``[1e8, 1, 1]`` instead of
    ``[1e8, 2, 1]``.  It also failed on an empty ``dim``, which now simply
    yields an empty result.

    Args:
        x: Input tensor.  It is not modified.
        dim: Dimension to scan along; negative values count from the end.

    Returns:
        A new tensor with the same shape as ``x``.  Its dtype follows
        ``torch.cumsum``'s promotion rules.
    """
    dims = (dim,)
    # Reverse, take a forward prefix sum, then reverse back: the prefix sum of
    # the reversed tensor is exactly the suffix sum of the original.
    return torch.flip(torch.cumsum(torch.flip(x, dims), dim=dim), dims)