def test_inv_matmul():
    """Check ``inv_matmul`` against a dense inverse-then-multiply reference.

    A random 4x4 right-hand side is wrapped in ``Variable`` and passed to
    ``inv_matmul`` on the first item returned by ``make_mul_lazy_var()``.
    The result's ``.data`` must agree, to within ``1e-3`` as measured by
    ``torch.norm``, with
    ``(t1_t2_t3_eval + added_diag.diag()).inverse().matmul(rhs)``.
    """
    # The right-hand side is drawn before the lazy variable is built, so the
    # sequence of calls matches the order the test has always used.
    rhs = torch.randn(4, 4)
    lazy_var = make_mul_lazy_var()[0]
    solved = lazy_var.inv_matmul(Variable(rhs))

    actual = solved.data

    # NOTE(review): t1_t2_t3_eval and added_diag are defined outside this
    # function; presumably they are the dense counterparts of what
    # make_mul_lazy_var() builds -- confirm against their definitions.
    dense_reference = t1_t2_t3_eval + added_diag.diag()
    expected = dense_reference.inverse().matmul(rhs)

    assert torch.norm(actual - expected) < 1e-3