def test_symeig(self):
    """Check that symeig's outputs rebuild the input as V * diag(e) * V^T.

    Three scenarios are exercised in order:
      1. freshly allocated, contiguous output tensors;
      2. the same output tensors reused for a second call;
      3. output tensors that are non-contiguous views from the start.

    NOTE(review): symeig is called in the out-first form
    ``symeig(e, v, input, True)``; the trailing ``True`` presumably asks
    for eigenvectors -- confirm against this torch version's signature.
    """

    def rebuild(vecs, vals):
        # V * diag(e) * V^T
        scaled = torch.mm(vecs, torch.diag(vals))
        return torch.mm(scaled, vecs.t())

    samples = torch.rand(100, 3)
    # (3x100) x (100x3) -> 3x3, symmetric by construction.
    gram = torch.mm(samples.t(), samples)
    eigvals = torch.zeros(3)
    eigvecs = torch.zeros(3, 3)

    # First call to symeig: outputs start out contiguous. A clone of the
    # input is passed, so `gram` itself is what gets compared afterwards.
    self.assertTrue(eigvecs.is_contiguous(), 'resv is not contiguous')
    torch.symeig(eigvals, eigvecs, gram.clone(), True)
    self.assertEqual(gram, rebuild(eigvecs, eigvals), 1e-8, 'VeV\' wrong')

    # Second call to symeig: the test expects the first call to have left
    # the eigenvector tensor non-contiguous, and reuses it as-is.
    self.assertFalse(eigvecs.is_contiguous(), 'resv is contiguous')
    torch.symeig(eigvals, eigvecs, gram.clone(), True)
    self.assertEqual(gram, rebuild(eigvecs, eigvals), 1e-8, 'VeV\' wrong')

    # test non-contiguous
    base = torch.rand(5, 5)
    # Elementwise product with the transpose yields a symmetric matrix.
    sym = base.t() * base
    # Both outputs are strided views into larger zero tensors.
    # NOTE(review): they are sized 4 and 4x4 while the input is 5x5 --
    # presumably symeig resizes its outputs; confirm.
    vals_view = torch.zeros(4, 2).select(1, 1)
    vecs_view = torch.zeros(4, 2, 4)[:,1]
    self.assertFalse(vecs_view.is_contiguous(), 'V is contiguous')
    self.assertFalse(vals_view.is_contiguous(), 'E is contiguous')
    torch.symeig(vals_view, vecs_view, sym, True)
    self.assertEqual(sym, rebuild(vecs_view, vals_view), 1e-8, 'VeV\' wrong')
# Comment list
# Table of contents