def test_trace_components_normal_matrices():
    """Check that ordinary random inputs come back from trace_components unchanged.

    Two independent 3x4 matrices are drawn with ``torch.randn`` and passed
    to ``trace_components``. Each tensor in the returned pair must equal the
    corresponding input exactly. ``torch.equal`` requires the same shape and
    identical element values.
    """
    first = torch.randn(3, 4)
    second = torch.randn(3, 4)

    outputs = trace_components(first, second)
    first_out, second_out = outputs

    assert torch.equal(first_out, first)
    assert torch.equal(second_out, second)