def test_svd(self):
a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
(9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
(9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
(5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
(3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
u, s, v = torch.svd(a)
uu = torch.Tensor()
ss = torch.Tensor()
vv = torch.Tensor()
uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
self.assertEqual(u, uu, 0, 'torch.svd')
self.assertEqual(u, uuu, 0, 'torch.svd')
self.assertEqual(s, ss, 0, 'torch.svd')
self.assertEqual(s, sss, 0, 'torch.svd')
self.assertEqual(v, vv, 0, 'torch.svd')
self.assertEqual(v, vvv, 0, 'torch.svd')
# test reuse
X = torch.randn(4, 4)
U, S, V = torch.svd(X)
Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
self.assertFalse(U.is_contiguous(), 'U is contiguous')
torch.svd(X, out=(U, S, V))
Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
# test non-contiguous
X = torch.randn(5, 5)
U = torch.zeros(5, 2, 5)[:, 1]
S = torch.zeros(5, 2)[:, 1]
V = torch.zeros(5, 2, 5)[:, 1]
self.assertFalse(U.is_contiguous(), 'U is contiguous')
self.assertFalse(S.is_contiguous(), 'S is contiguous')
self.assertFalse(V.is_contiguous(), 'V is contiguous')
torch.svd(X, out=(U, S, V))
Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
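
# The reconstruction checks above all rely on the identity X = U · diag(S) · V^T.
# A minimal standalone sketch of the same check outside the test harness, using
# the old-style torch.svd API that appears throughout this page:
import torch

X = torch.randn(4, 4)
U, S, V = torch.svd(X)
# torch.svd returns S as a 1-D tensor of singular values;
# S.diag() lifts it back to a diagonal matrix for the product
Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
print((X - Xhat).abs().max())  # ~1e-6 for FloatTensor inputs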
def test_torch():
import torch
    from torch.autograd import Variable
    from block import block  # assumed: the bamos/block package providing block()
torch.manual_seed(0)
nx, nineq, neq = 4, 6, 7
Q = torch.randn(nx, nx)
G = torch.randn(nineq, nx)
A = torch.randn(neq, nx)
D = torch.diag(torch.rand(nineq))
K_ = torch.cat((
torch.cat((Q, torch.zeros(nx, nineq).type_as(Q), G.t(), A.t()), 1),
torch.cat((torch.zeros(nineq, nx).type_as(Q), D,
torch.eye(nineq).type_as(Q),
torch.zeros(nineq, neq).type_as(Q)), 1),
torch.cat((G, torch.eye(nineq).type_as(Q), torch.zeros(
nineq, nineq + neq).type_as(Q)), 1),
torch.cat((A, torch.zeros((neq, nineq + nineq + neq))), 1)
))
K = block((
(Q, 0, G.t(), A.t()),
(0, D, 'I', 0),
(G, 'I', 0, 0),
(A, 0, 0, 0)
))
assert (K - K_).norm() == 0.0
K = block((
(Variable(Q), 0, G.t(), Variable(A.t())),
(0, Variable(D), 'I', 0),
(Variable(G), 'I', 0, 0),
(A, 0, 0, 0)
))
assert (K.data - K_).norm() == 0.0
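
# For reference, a smaller self-contained sketch of the same block() shorthand
# (assuming the bamos/block package this test appears to exercise): 0 expands to
# an appropriately sized zero block and 'I' to an identity block, with sizes
# inferred from the named blocks in each row and column.
import torch
from block import block  # assumed import

A = torch.randn(2, 3)
M = block(((A, 0),
           ('I', A.t())))
print(M.size())  # torch.Size([5, 5]): row heights 2+3, column widths 3+2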
import numpy as np
import numpy.random as npr
import scipy.sparse.linalg as sla
from block import block  # assumed: the bamos/block package providing block()

def test_linear_operator():
npr.seed(0)
nx, nineq, neq = 4, 6, 7
Q = npr.randn(nx, nx)
G = npr.randn(nineq, nx)
A = npr.randn(neq, nx)
D = np.diag(npr.rand(nineq))
K_ = np.bmat((
(Q, np.zeros((nx, nineq)), G.T, A.T),
(np.zeros((nineq, nx)), D, np.eye(nineq), np.zeros((nineq, neq))),
(G, np.eye(nineq), np.zeros((nineq, nineq + neq))),
(A, np.zeros((neq, nineq + nineq + neq)))
))
Q_lo = sla.aslinearoperator(Q)
G_lo = sla.aslinearoperator(G)
A_lo = sla.aslinearoperator(A)
D_lo = sla.aslinearoperator(D)
K = block((
(Q_lo, 0, G.T, A.T),
(0, D_lo, 'I', 0),
(G_lo, 'I', 0, 0),
(A_lo, 0, 0, 0)
), arrtype=sla.LinearOperator)
w1 = np.random.randn(K_.shape[1])
assert np.allclose(K_.dot(w1), K.dot(w1))
w2 = np.random.randn(K_.shape[0])
assert np.allclose(K_.T.dot(w2), K.H.dot(w2))
W = np.random.randn(*K_.shape)
assert np.allclose(K_.dot(W), K.dot(W))
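
# The LinearOperator variant above never materializes K as a dense array; a
# minimal illustration of the scipy API it builds on:
import numpy as np
import scipy.sparse.linalg as sla

M = np.random.randn(3, 3)
M_lo = sla.aslinearoperator(M)
w = np.random.randn(3)
print(np.allclose(M_lo.matvec(w), M.dot(w)))      # True
print(np.allclose(M_lo.H.matvec(w), M.T.dot(w)))  # adjoint == transpose for real M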
def forward(ctx, input, diagonal_idx=0):
ctx.diagonal_idx = diagonal_idx
return input.diag(ctx.diagonal_idx)
def backward(ctx, grad_output):
return grad_output.diag(ctx.diagonal_idx), None
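
# The two methods above look like the forward/backward of an old-style autograd
# Function. A minimal sketch of how such a Function would be wrapped and applied
# (the class name Diag is hypothetical):
import torch
from torch.autograd import Function, Variable

class Diag(Function):
    @staticmethod
    def forward(ctx, input, diagonal_idx=0):
        ctx.diagonal_idx = diagonal_idx
        return input.diag(ctx.diagonal_idx)

    @staticmethod
    def backward(ctx, grad_output):
        # diag() maps vectors to diagonal matrices and back, and is
        # self-adjoint as a linear map, so the backward is again a diag()
        return grad_output.diag(ctx.diagonal_idx), None

v = Variable(torch.randn(3), requires_grad=True)
M = Diag.apply(v)    # 3x3 matrix with v on the diagonal
M.sum().backward()
print(v.grad)        # all ones: only the diagonal entries receive gradient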
def phi(A):
"""
Return lower triangle of A and halve the diagonal.
"""
B = A.tril()
B = B - 0.5 * torch.diag(torch.diag(B))
return B
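
# A quick worked example of phi: the strict lower triangle passes through
# unchanged while the diagonal is halved.
import torch

A = torch.Tensor([[4., 1.],
                  [2., 3.]])
B = phi(A)
# tril(A) = [[4, 0], [2, 3]]; subtracting half of its diagonal gives
# [[2.0, 0.0],
#  [2.0, 1.5]]
print(B)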
def test_eig(self):
a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
(-6.49, 3.80, 0.00, 0.00, 0.00),
(-0.47, -6.39, 4.17, 0.00, 0.00),
(-7.20, 1.50, -1.51, 5.70, 0.00),
(-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
e = torch.eig(a)[0]
ee, vv = torch.eig(a, True)
te = torch.Tensor()
tv = torch.Tensor()
eee, vvv = torch.eig(a, True, out=(te, tv))
self.assertEqual(e, ee, 1e-12)
self.assertEqual(ee, eee, 1e-12)
self.assertEqual(ee, te, 1e-12)
self.assertEqual(vv, vvv, 1e-12)
self.assertEqual(vv, tv, 1e-12)
# test reuse
X = torch.randn(4, 4)
X = torch.mm(X.t(), X)
e, v = torch.zeros(4, 2), torch.zeros(4, 4)
torch.eig(X, True, out=(e, v))
Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
self.assertFalse(v.is_contiguous(), 'V is contiguous')
torch.eig(X, True, out=(e, v))
Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
self.assertFalse(v.is_contiguous(), 'V is contiguous')
# test non-contiguous
X = torch.randn(4, 4)
X = torch.mm(X.t(), X)
e = torch.zeros(4, 2, 2)[:, 1]
v = torch.zeros(4, 2, 4)[:, 1]
self.assertFalse(v.is_contiguous(), 'V is contiguous')
self.assertFalse(e.is_contiguous(), 'E is contiguous')
torch.eig(X, True, out=(e, v))
Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
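
# torch.eig returns eigenvalues as an n x 2 tensor of (real, imaginary) pairs --
# hence the e.select(1, 0) above -- whereas torch.symeig returns a flat
# eigenvalue vector. A minimal sketch on a symmetric matrix, where the
# imaginary column vanishes:
import torch

X = torch.randn(4, 4)
X = torch.mm(X.t(), X)             # symmetric, so eigenvalues are real
e, v = torch.eig(X, True)
print(e.select(1, 1).abs().max())  # imaginary parts: ~0
Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
print((X - Xhat).abs().max())      # ~1e-6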
def test_symeig(self):
xval = torch.rand(100, 3)
cov = torch.mm(xval.t(), xval)
rese = torch.zeros(3)
resv = torch.zeros(3, 3)
# First call to symeig
self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
torch.symeig(cov.clone(), True, out=(rese, resv))
ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')
# Second call to symeig
self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
torch.symeig(cov.clone(), True, out=(rese, resv))
ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')
# test non-contiguous
X = torch.rand(5, 5)
X = X.t() * X
    # out tensors must match the 5x5 input, or symeig will resize them
    # (and make them contiguous, defeating this non-contiguity test)
    e = torch.zeros(5, 2).select(1, 1)
    v = torch.zeros(5, 2, 5)[:, 1]
self.assertFalse(v.is_contiguous(), 'V is contiguous')
self.assertFalse(e.is_contiguous(), 'E is contiguous')
torch.symeig(X, True, out=(e, v))
Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
def pending_test_diag():
    # WKW, lazy_kronecker_product_var and utils are fixtures defined elsewhere
    # in the original test module; this snippet is not self-contained.
    diag_actual = torch.diag(WKW)
    diag_res = lazy_kronecker_product_var.diag()
    assert utils.approx_equal(diag_res.data, diag_actual)
def add_diag(self, diag):
if self.added_diag is None:
return MulLazyVariable(*self.lazy_vars,
matmul_mode=self.matmul_mode,
max_iter=self.max_iter,
num_samples=self.num_samples,
added_diag=diag.expand(self.size()[0]))
else:
return MulLazyVariable(*self.lazy_vars,
matmul_mode=self.matmul_mode,
max_iter=self.max_iter,
num_samples=self.num_samples,
added_diag=self.added_diag + diag)
def diag(self):
res = Variable(torch.ones(self.size()[0]))
for lazy_var in self.lazy_vars:
res = res * lazy_var.diag()
if self.added_diag is not None:
res = res + self.added_diag
return res
def evaluate(self):
res = None
for lazy_var in self.lazy_vars:
if res is None:
res = lazy_var.evaluate()
else:
res = res * lazy_var.evaluate()
if self.added_diag is not None:
res = res + self.added_diag.diag()
return res
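
# The diag() method above relies on the fact that, for an elementwise
# (Hadamard) product of matrices, the diagonal of the product is the
# elementwise product of the diagonals. A quick check with dense tensors:
import torch

A = torch.rand(4, 4)
B = torch.rand(4, 4)
prod = A * B                          # Hadamard product, as in evaluate()
print(torch.diag(prod))
print(torch.diag(A) * torch.diag(B))  # identical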
def factor_kkt(U_S, R, d):
    """Factor the U22 block; this can only be done once D is known."""
nineq = R.size(0)
U_S[-nineq:, -nineq:] = torch.potrf(R + torch.diag(1 / d.cpu()).type_as(d))
def random_square_matrix_of_rank(l, rank):
assert rank <= l
A = torch.randn(l, l)
u, s, v = A.svd()
for i in range(l):
if i >= rank:
s[i] = 0
elif s[i] == 0:
s[i] = 1
return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
def random_fullrank_matrix_distinct_singular_value(l):
A = torch.randn(l, l)
u, _, v = A.svd()
s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
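
# A quick sanity check of the two helpers above (sketch): the first should
# return a matrix of the requested rank, the second a full-rank matrix with
# singular values k / (l + 1).
import torch

M = random_square_matrix_of_rank(5, 3)
_, s, _ = M.svd()
print((s > 1e-5).sum())  # 3: only `rank` singular values survive (w.h.p.)

F = random_fullrank_matrix_distinct_singular_value(5)
_, s, _ = F.svd()
print(s)                 # ~ [5/6, 4/6, 3/6, 2/6, 1/6], all distinct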
def test_diag(self):
x = torch.rand(100, 100)
res1 = torch.diag(x)
res2 = torch.Tensor()
torch.diag(x, out=res2)
self.assertEqual(res1, res2)
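
# torch.diag is overloaded: on a 1-D input it builds a diagonal matrix, on a
# 2-D input it extracts the diagonal as a vector, so applying it twice
# round-trips:
import torch

v = torch.Tensor([1, 2, 3])
M = torch.diag(v)  # 3x3 matrix with v on the diagonal
d = torch.diag(M)  # recovers v
print(d)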
def cublas_dgmm(A, x, out=None):
    """Multiply A by diag(x): scale the columns (A @ diag(x)) when x matches
    A's last dimension, or the rows (diag(x) @ A) when it matches the first.
    Uses cuBLAS ?dgmm for CUDA tensors and torch.diag + torch.mm otherwise."""
    if out is not None:
        assert out.is_contiguous() and out.size() == A.size()
    else:
        out = A.new(A.size())
    assert x.dim() == 1
    assert x.numel() == A.size(-1) or x.numel() == A.size(0)
    assert A.type() == x.type() == out.type()
    assert A.is_contiguous()
    if not isinstance(A, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor)):
        # CPU fallback: materialize diag(x) and use an ordinary matmul
        if x.numel() == A.size(-1):
            return torch.mm(A, torch.diag(x), out=out.view_as(A))
        else:
            return torch.mm(torch.diag(x), A, out=out.view_as(A))
else:
if x.numel() == A.size(-1):
m, n = A.size(-1), A.numel() // A.size(-1)
mode = 'l'
# A.mm(x.diag(), out=out)
# return out
elif x.numel() == A.size(0):
n, m = A.size(0), A.numel() // A.size(0)
mode = 'r'
# if A.stride(0) == 1:
# mode = 'l'
# n, m = m, n
# x.diag().mm(A, out=out)
# return out
lda, ldc = m, m
incx = 1
handle = torch.cuda.current_blas_handle()
stream = torch.cuda.current_stream()._as_parameter_
from skcuda import cublas
cublas.cublasSetStream(handle, stream)
args = [handle, mode, m, n, A.data_ptr(), lda, x.data_ptr(), incx, out.data_ptr(), ldc]
if isinstance(A, torch.cuda.FloatTensor):
cublas.cublasSdgmm(*args)
elif isinstance(A, torch.cuda.DoubleTensor):
cublas.cublasDdgmm(*args)
return out
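
# A usage sketch for cublas_dgmm (the CUDA path needs a GPU and scikit-cuda;
# the CPU fallback runs anywhere). With x matching A's last dimension it
# computes A @ diag(x), i.e. scales the columns of A:
import torch

A = torch.randn(4, 3)
x = torch.rand(3)
out = cublas_dgmm(A, x)
print((out - torch.mm(A, torch.diag(x))).abs().max())  # 0 on the CPU path

# On CUDA tensors the same call dispatches to cublasSdgmm / cublasDdgmm:
# out_gpu = cublas_dgmm(A.cuda(), x.cuda())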