def pending_test_diag():
    """Check the lazy Kronecker product's diagonal against a dense reference.

    The dense diagonal is taken from ``WKW`` with ``torch.diag``. The lazy
    result comes from ``lazy_kronecker_product_var.diag()``, whose ``.data``
    is compared using ``utils.approx_equal``.

    NOTE(review): the ``pending_`` prefix presumably keeps pytest's default
    ``test*`` function collection from running this check; rename to
    ``test_diag`` once it is meant to be enabled.
    """
    expected = torch.diag(WKW)
    computed = lazy_kronecker_product_var.diag().data
    assert utils.approx_equal(computed, expected)