def test_add_diag():
    """Check the evaluated lazy variable against its dense expected value.

    Takes the first element returned by ``make_mul_lazy_var()``, evaluates it
    densely and compares it with ``t1_t2_t3_eval + added_diag.diag()``.
    NOTE(review): presumably ``added_diag`` is the 1-D diagonal that
    ``make_mul_lazy_var`` adds, so ``.diag()`` expands it to a matrix.
    Confirm this against the fixture definitions.
    """
    # Only the first element of the factory's result is the object under test.
    factory_result = make_mul_lazy_var()
    lazy_var_with_diag = factory_result[0]

    # Evaluate the lazy representation first, then build the reference value.
    # This matches the original argument evaluation order.
    actual = lazy_var_with_diag.evaluate().data
    expected = t1_t2_t3_eval + added_diag.diag()

    # torch.equal requires identical shapes and exactly equal elements
    # (no floating-point tolerance).
    assert torch.equal(actual, expected)