def test_svd(self):
    """Check ``torch.svd`` with fresh, reused and non-contiguous outputs.

    Output tensors are supplied through the ``out=(U, S, V)`` keyword;
    the legacy form that passed them as leading positional arguments is
    not accepted by ``torch.svd`` any more.
    """
    a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                      (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                      (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                      (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                      (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
    u, s, v = torch.svd(a)
    uu = torch.Tensor()
    ss = torch.Tensor()
    vv = torch.Tensor()
    # Both the tensors filled through `out` and the tensors returned by the
    # call must match the plain call exactly (tolerance 0).
    uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
    self.assertEqual(u, uu, 0, 'torch.svd')
    self.assertEqual(u, uuu, 0, 'torch.svd')
    self.assertEqual(s, ss, 0, 'torch.svd')
    self.assertEqual(s, sss, 0, 'torch.svd')
    self.assertEqual(v, vv, 0, 'torch.svd')
    self.assertEqual(v, vvv, 0, 'torch.svd')

    # test reuse: write a second decomposition into the tensors returned
    # by the first one and check that U * diag(S) * V^T still rebuilds X
    X = torch.randn(4, 4)
    U, S, V = torch.svd(X)
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    # test non-contiguous: the outputs are strided views into larger buffers
    X = torch.randn(5, 5)
    U = torch.zeros(5, 2, 5)[:, 1]
    S = torch.zeros(5, 2)[:, 1]
    V = torch.zeros(5, 2, 5)[:, 1]

    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    self.assertFalse(S.is_contiguous(), 'S is contiguous')
    self.assertFalse(V.is_contiguous(), 'V is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
# Example source code for svd() in Python
def test_svd(self):
    """Exercise ``torch.svd`` with fresh, reused and non-contiguous ``out`` tensors."""
    rows = (
        (8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
        (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
        (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
        (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
        (3.16, 7.98, 3.01, 5.80, 4.27, -5.31),
    )
    a = torch.Tensor(rows).t().clone()

    def rebuild(left, sigma, right):
        # Recompose a matrix as left * diag(sigma) * right^T.
        return torch.mm(left, torch.mm(sigma.diag(), right.t()))

    # Tensors filled through `out`, and the tensors returned by that same
    # call, must both equal the result of a plain call (tolerance 0).
    ref_u, ref_s, ref_v = torch.svd(a)
    out_u, out_s, out_v = torch.Tensor(), torch.Tensor(), torch.Tensor()
    ret_u, ret_s, ret_v = torch.svd(a, out=(out_u, out_s, out_v))
    triples = (
        (ref_u, out_u, ret_u),
        (ref_s, out_s, ret_s),
        (ref_v, out_v, ret_v),
    )
    for expected, filled, returned in triples:
        self.assertEqual(expected, filled, 0, 'torch.svd')
        self.assertEqual(expected, returned, 0, 'torch.svd')

    # Reuse: decompose again into the tensors produced by the first call.
    square = torch.randn(4, 4)
    left, sigma, right = torch.svd(square)
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')

    # Non-contiguous: the outputs are strided views into larger zero buffers.
    square = torch.randn(5, 5)
    left = torch.zeros(5, 2, 5)[:, 1]
    sigma = torch.zeros(5, 2)[:, 1]
    right = torch.zeros(5, 2, 5)[:, 1]
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    self.assertFalse(sigma.is_contiguous(), 'S is contiguous')
    self.assertFalse(right.is_contiguous(), 'V is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
def test_svd(self):
    """Exercise ``torch.svd`` with fresh, reused and non-contiguous ``out`` tensors."""
    rows = (
        (8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
        (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
        (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
        (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
        (3.16, 7.98, 3.01, 5.80, 4.27, -5.31),
    )
    a = torch.Tensor(rows).t().clone()

    def rebuild(left, sigma, right):
        # Recompose a matrix as left * diag(sigma) * right^T.
        return torch.mm(left, torch.mm(sigma.diag(), right.t()))

    # Tensors filled through `out`, and the tensors returned by that same
    # call, must both equal the result of a plain call (tolerance 0).
    ref_u, ref_s, ref_v = torch.svd(a)
    out_u, out_s, out_v = torch.Tensor(), torch.Tensor(), torch.Tensor()
    ret_u, ret_s, ret_v = torch.svd(a, out=(out_u, out_s, out_v))
    triples = (
        (ref_u, out_u, ret_u),
        (ref_s, out_s, ret_s),
        (ref_v, out_v, ret_v),
    )
    for expected, filled, returned in triples:
        self.assertEqual(expected, filled, 0, 'torch.svd')
        self.assertEqual(expected, returned, 0, 'torch.svd')

    # Reuse: decompose again into the tensors produced by the first call.
    square = torch.randn(4, 4)
    left, sigma, right = torch.svd(square)
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')

    # Non-contiguous: the outputs are strided views into larger zero buffers.
    square = torch.randn(5, 5)
    left = torch.zeros(5, 2, 5)[:, 1]
    sigma = torch.zeros(5, 2)[:, 1]
    right = torch.zeros(5, 2, 5)[:, 1]
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    self.assertFalse(sigma.is_contiguous(), 'S is contiguous')
    self.assertFalse(right.is_contiguous(), 'V is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
def test_svd(self):
    """Exercise ``torch.svd`` with fresh, reused and non-contiguous ``out`` tensors."""
    rows = (
        (8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
        (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
        (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
        (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
        (3.16, 7.98, 3.01, 5.80, 4.27, -5.31),
    )
    a = torch.Tensor(rows).t().clone()

    def rebuild(left, sigma, right):
        # Recompose a matrix as left * diag(sigma) * right^T.
        return torch.mm(left, torch.mm(sigma.diag(), right.t()))

    # Tensors filled through `out`, and the tensors returned by that same
    # call, must both equal the result of a plain call (tolerance 0).
    ref_u, ref_s, ref_v = torch.svd(a)
    out_u, out_s, out_v = torch.Tensor(), torch.Tensor(), torch.Tensor()
    ret_u, ret_s, ret_v = torch.svd(a, out=(out_u, out_s, out_v))
    triples = (
        (ref_u, out_u, ret_u),
        (ref_s, out_s, ret_s),
        (ref_v, out_v, ret_v),
    )
    for expected, filled, returned in triples:
        self.assertEqual(expected, filled, 0, 'torch.svd')
        self.assertEqual(expected, returned, 0, 'torch.svd')

    # Reuse: decompose again into the tensors produced by the first call.
    square = torch.randn(4, 4)
    left, sigma, right = torch.svd(square)
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')

    # Non-contiguous: the outputs are strided views into larger zero buffers.
    square = torch.randn(5, 5)
    left = torch.zeros(5, 2, 5)[:, 1]
    sigma = torch.zeros(5, 2)[:, 1]
    right = torch.zeros(5, 2, 5)[:, 1]
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    self.assertFalse(sigma.is_contiguous(), 'S is contiguous')
    self.assertFalse(right.is_contiguous(), 'V is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
def func():
    """Call ``linalg.svd`` on the module-level ``np_data``; the result is discarded."""
    linalg.svd(np_data)
def func():
    """Call ``linalg.svd`` on ``np_data`` with the ``gesvd`` LAPACK driver; the result is discarded."""
    linalg.svd(np_data, lapack_driver='gesvd')
def func():
    """Call ``linalg.svd`` on ``np_data`` with the ``gesdd`` LAPACK driver; the result is discarded."""
    linalg.svd(np_data, lapack_driver='gesdd')
def func():
    """Run ``torch.svd`` on a fresh random ``N`` x ``N`` tensor; the result is discarded."""
    sample = torch.rand((N, N))
    torch.svd(sample)
def func():
    """Run ``torch.svd`` on a fresh random ``N`` x ``N`` tensor moved with ``.cuda()``; the result is discarded."""
    sample = torch.rand((N, N)).cuda()
    torch.svd(sample)
def partial_svd(matrix, n_eigenvecs=None):
    """Compute an SVD of `matrix` with ``torch.svd`` and keep the leading part.

    The decomposition is computed with ``some=False`` and each factor is
    then sliced down to its first `n_eigenvecs` components. With
    ``n_eigenvecs=None`` the slices keep everything.

    Parameters
    ----------
    matrix : 2D-array
    n_eigenvecs : int, optional, default is None
        if specified, number of leading singular values/vectors to keep

    Returns
    -------
    U : 2D-array
        first `n_eigenvecs` columns of the ``U`` factor
        (the left singular vectors)
    S : 1D-array
        first `n_eigenvecs` singular values of `matrix`
    V : 2D-array
        first `n_eigenvecs` rows of the transposed ``V`` factor
        (the right singular vectors, one per row)

    Raises
    ------
    ValueError
        if ``ndim(matrix)`` is not 2
    """
    # Only 2D inputs can be decomposed.
    order = ndim(matrix)
    if order != 2:
        raise ValueError('matrix be a matrix. matrix.ndim is {} != 2'.format(
            order))

    left, sigma, right = torch.svd(matrix, some=False)
    # One slice object serves all three truncations; slice(None, None)
    # keeps every component when n_eigenvecs is None.
    leading = slice(None, n_eigenvecs)
    return left[:, leading], sigma[leading], right.t()[leading, :]
def det(var):
    r"""Return the determinant of a 2D square Variable.

    .. note::
        The backward pass of `det` relies on SVD results internally, so a
        double backward through `det` also goes backward through
        :meth:`~Tensor.svd`. That can be unstable in certain cases; see
        :meth:`~torch.svd` for details.

    Arguments:
        var (Variable): The input 2D square Variable.

    Raises:
        ValueError: if ``torch.is_tensor(var)`` is true.
    """
    # Anything that is not reported as a tensor is delegated to its own
    # ``det`` method; tensors are rejected.
    if not torch.is_tensor(var):
        return var.det()
    raise ValueError("det is currently only supported on Variable")
def test_svd(self):
    """Exercise ``torch.svd`` with fresh, reused and non-contiguous ``out`` tensors."""
    rows = (
        (8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
        (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
        (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
        (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
        (3.16, 7.98, 3.01, 5.80, 4.27, -5.31),
    )
    a = torch.Tensor(rows).t().clone()

    def rebuild(left, sigma, right):
        # Recompose a matrix as left * diag(sigma) * right^T.
        return torch.mm(left, torch.mm(sigma.diag(), right.t()))

    # Tensors filled through `out`, and the tensors returned by that same
    # call, must both equal the result of a plain call (tolerance 0).
    ref_u, ref_s, ref_v = torch.svd(a)
    out_u, out_s, out_v = torch.Tensor(), torch.Tensor(), torch.Tensor()
    ret_u, ret_s, ret_v = torch.svd(a, out=(out_u, out_s, out_v))
    triples = (
        (ref_u, out_u, ret_u),
        (ref_s, out_s, ret_s),
        (ref_v, out_v, ret_v),
    )
    for expected, filled, returned in triples:
        self.assertEqual(expected, filled, 0, 'torch.svd')
        self.assertEqual(expected, returned, 0, 'torch.svd')

    # Reuse: decompose again into the tensors produced by the first call.
    square = torch.randn(4, 4)
    left, sigma, right = torch.svd(square)
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')

    # Non-contiguous: the outputs are strided views into larger zero buffers.
    square = torch.randn(5, 5)
    left = torch.zeros(5, 2, 5)[:, 1]
    sigma = torch.zeros(5, 2)[:, 1]
    right = torch.zeros(5, 2, 5)[:, 1]
    self.assertFalse(left.is_contiguous(), 'U is contiguous')
    self.assertFalse(sigma.is_contiguous(), 'S is contiguous')
    self.assertFalse(right.is_contiguous(), 'V is contiguous')
    torch.svd(square, out=(left, sigma, right))
    self.assertEqual(square, rebuild(left, sigma, right), 1e-8, 'USV\' wrong')