def backward(ctx, grad_output):
    """Backpropagate through a matrix inverse.

    For ``Y = A^{-1}`` the vector-Jacobian product is
    ``dL/dA = -Y^T @ (dL/dY) @ Y^T``, which is what this computes from the
    single tensor saved on ``ctx``.

    Args:
        ctx: Autograd context holding exactly one saved tensor. It is
            presumably the inverse ``A^{-1}`` saved by ``forward`` (inferred
            from the variable name) — confirm against the forward pass.
        grad_output: Gradient of the loss w.r.t. the forward output. Must be a
            2-D matrix, because ``torch.mm`` only accepts 2-D inputs.

    Returns:
        The gradient with respect to the forward input matrix. It is a
        2-D tensor with the same shape as ``grad_output``.
    """
    # ``saved_variables`` is deprecated and warns on access; ``saved_tensors``
    # is the supported accessor and yields the same saved tensors.
    inverse, = ctx.saved_tensors
    # Transpose once and reuse it on both sides of the product.
    inverse_t = inverse.t()
    return -torch.mm(inverse_t, torch.mm(grad_output, inverse_t))