def test_symeig(self):
    """Check that ``torch.symeig`` on CUDA reconstructs a symmetric matrix.

    For each size, a random CUDA matrix ``base`` is made symmetric via
    ``base @ base.t()``. It is decomposed with
    ``torch.symeig(..., eigenvectors=True)``, and the product
    ``V @ diag(w) @ V.t()`` is compared against the symmetric input with
    ``self.assertEqual``. A 3x3 "small case" runs first, then a 257x257
    "large case".
    """
    # NOTE(review): torch.symeig is deprecated in newer PyTorch releases in
    # favour of torch.linalg.eigh; kept here to match the API this file uses.
    for size in (3, 257):  # small case, then large case
        base = torch.randn(size, size).cuda()
        symmetric = base.mm(base.t())
        values, vectors = torch.symeig(symmetric, eigenvectors=True)
        rebuilt = vectors.mm(values.diag()).mm(vectors.t())
        self.assertEqual(symmetric, rebuilt)