def backward(ctx, grad_output):
    """Backpropagate through the elementwise power ``a ** b`` to both operands.

    Uses the standard partial derivatives
        d(a**b)/da = b * a**(b - 1)
        d(a**b)/db = a**b * log(a)

    Args:
        ctx: Autograd context. ``ctx.saved_variables`` must unpack to
            ``(a, b)`` and ``ctx.b_size`` must hold the size handed to
            ``maybe_view`` for the exponent gradient. Both are presumably set
            by the matching ``forward``, which is not visible here; confirm there.
        grad_output: Gradient of the loss with respect to ``a ** b``.

    Returns:
        ``(grad_a, grad_b)``. ``grad_b`` is passed through ``maybe_view``
        together with ``ctx.b_size``. ``maybe_view`` is defined elsewhere and
        presumably restores ``b``'s original shape; verify against its definition.
    """
    a, b = ctx.saved_variables
    # Compute a**b once; it feeds both the exponent gradient and the zero mask.
    result = a.pow(b)
    grad_a = grad_output.mul(b).mul(a.pow(b - 1))
    grad_b = grad_output.mul(result).mul(a.log())
    # For a == 0 and b > 0 the product above is 0 * log(0) = 0 * -inf = NaN,
    # although a**b stays 0 under small changes of b, so the true gradient is 0.
    # Zero every position where a**b is exactly 0. For nonzero a that
    # underflowed to 0, log(a) is finite, so the product was already 0 and the
    # mask changes nothing.
    grad_b = grad_b.masked_fill(result.eq(0), 0)
    return grad_a, maybe_view(grad_b, ctx.b_size)