def forward(ctx, input, dim):
    """Return the running (cumulative) product of ``input`` along ``dim``.

    This looks like the forward pass of a ``torch.autograd.Function``. It
    records the accumulation axis on ``ctx`` and registers ``input`` through
    ``ctx.save_for_backward``, presumably so that a backward pass defined
    elsewhere can rebuild the gradient.

    Args:
        ctx: Autograd context object supplied by PyTorch.
        input: Tensor whose running product is taken.
        dim: Axis along which the product is accumulated.

    Returns:
        A tensor shaped like ``input`` where each entry is the product of all
        entries up to and including it along ``dim``.
    """
    # Remember the accumulation axis; it is not a tensor, so it cannot go
    # through save_for_backward.
    ctx.dim = dim
    # Keep the original input around; the gradient of a cumulative product
    # depends on the input values.
    ctx.save_for_backward(input)
    running_product = input.cumprod(dim=ctx.dim)
    return running_product