python类diag()的实例源码

test_torch.py 文件源码 项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_svd(self):
        """Exercise ``torch.svd`` with and without ``out=`` and check the factors.

        Three scenarios are covered: a fixed literal matrix (plain call vs.
        ``out=`` call), re-using the tensors returned by a previous call as
        ``out=`` targets, and writing into non-contiguous ``out=`` views. The
        last two verify that ``U * diag(S) * V^T`` reproduces the input.
        """
        # Five rows of six literals, transposed and cloned: ``a`` is 6 x 5.
        a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                          (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                          (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                          (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                          (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        # Empty tensors used as out= destinations.
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
        # Both the out= tensors and the returned tensors must match the plain
        # call. The third argument (0) is presumably the allowed tolerance of
        # this test case's assertEqual -- confirm against the test base class.
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # The test expects the U returned above to be non-contiguous, then
        # feeds the same three tensors back in as out= targets.
        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        # Index 1 along dim 1 of a larger zero tensor gives strided views;
        # the assertions below check that they really are non-contiguous.
        U = torch.zeros(5, 2, 5)[:, 1]
        S = torch.zeros(5, 2)[:, 1]
        V = torch.zeros(5, 2, 5)[:, 1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
test.py 文件源码 项目:block 作者: bamos 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def test_torch():
    """Compare ``block`` output on torch inputs with a hand-built ``torch.cat`` matrix.

    The same 4 x 4 block layout is assembled twice with ``block``: once from
    plain tensors and once with some entries wrapped in ``Variable``. Each
    result must equal the reference ``K_`` exactly (zero norm difference).
    """
    import torch
    from torch.autograd import Variable

    # Fixed seed so the random blocks are reproducible.
    torch.manual_seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = torch.randn(nx, nx)
    G = torch.randn(nineq, nx)
    A = torch.randn(neq, nx)
    D = torch.diag(torch.rand(nineq))

    # Reference matrix: each inner cat (dim 1) builds one block row, the
    # outer cat stacks the four block rows.
    K_ = torch.cat((
        torch.cat((Q, torch.zeros(nx, nineq).type_as(Q), G.t(), A.t()), 1),
        torch.cat((torch.zeros(nineq, nx).type_as(Q), D,
                   torch.eye(nineq).type_as(Q),
                   torch.zeros(nineq, neq).type_as(Q)), 1),
        torch.cat((G, torch.eye(nineq).type_as(Q), torch.zeros(
            nineq, nineq + neq).type_as(Q)), 1),
        torch.cat((A, torch.zeros((neq, nineq + nineq + neq))), 1)
    ))

    # Same layout via block(): the 0 entries sit where K_ has zero blocks
    # and the 'I' entries sit where K_ has torch.eye(nineq).
    K = block((
        (Q,   0, G.t(), A.t()),
        (0,   D,   'I',     0),
        (G, 'I',     0,     0),
        (A,   0,     0,     0)
    ))

    assert (K - K_).norm() == 0.0
    # Second pass mixes Variable-wrapped and plain tensor entries; the result
    # is read through .data before comparing with the tensor reference.
    K = block((
        (Variable(Q),   0, G.t(), Variable(A.t())),
        (0,   Variable(D),   'I',     0),
        (Variable(G), 'I',     0,     0),
        (A,   0,     0,     0)
    ))

    assert (K.data - K_).norm() == 0.0
test.py 文件源码 项目:block 作者: bamos 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def test_linear_operator():
    """Check ``block(..., arrtype=sla.LinearOperator)`` against a dense ``np.bmat``.

    The operator built by ``block`` must agree with the dense reference for a
    matrix-vector product, for the product with ``K.H`` (compared with the
    transpose of the reference), and for a matrix-matrix product.

    NOTE(review): ``npr``, ``np`` and ``sla`` are module-level names;
    presumably ``numpy.random``, ``numpy`` and ``scipy.sparse.linalg`` --
    confirm against the file's imports.
    """
    npr.seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = npr.randn(nx, nx)
    G = npr.randn(nineq, nx)
    A = npr.randn(neq, nx)
    D = np.diag(npr.rand(nineq))

    # Dense reference with explicit zero and identity blocks.
    K_ = np.bmat((
        (Q, np.zeros((nx, nineq)), G.T, A.T),
        (np.zeros((nineq, nx)), D, np.eye(nineq), np.zeros((nineq, neq))),
        (G, np.eye(nineq), np.zeros((nineq, nineq + neq))),
        (A, np.zeros((neq, nineq + nineq + neq)))
    ))

    Q_lo = sla.aslinearoperator(Q)
    G_lo = sla.aslinearoperator(G)
    A_lo = sla.aslinearoperator(A)
    D_lo = sla.aslinearoperator(D)

    # The first column and D use linear operators while G.T / A.T stay
    # plain arrays, so block() receives a mix of both kinds of entry.
    K = block((
        (Q_lo,    0,    G.T,    A.T),
        (0,    D_lo,    'I',      0),
        (G_lo,  'I',      0,      0),
        (A_lo,    0,      0,      0)
    ), arrtype=sla.LinearOperator)

    # Forward product with a random vector.
    w1 = np.random.randn(K_.shape[1])
    assert np.allclose(K_.dot(w1), K.dot(w1))
    # Product with K.H, compared with the transposed dense reference.
    w2 = np.random.randn(K_.shape[0])
    assert np.allclose(K_.T.dot(w2), K.H.dot(w2))
    # Product with a full random matrix of the same shape as K_.
    W = np.random.randn(*K_.shape)
    assert np.allclose(K_.dot(W), K.dot(W))
linalg.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def forward(ctx, input, diagonal_idx=0):
        """Apply ``input.diag`` at offset ``diagonal_idx`` and record the offset on ``ctx``.

        Args:
            ctx: context object; receives the attribute ``diagonal_idx``.
            input: object exposing ``diag(offset)``.
            diagonal_idx: diagonal offset forwarded to ``input.diag`` (default 0).

        Returns:
            Whatever ``input.diag(diagonal_idx)`` returns.
        """
        # Store the offset on ctx before using it, so it can be read back later.
        offset = ctx.diagonal_idx = diagonal_idx
        return input.diag(offset)
linalg.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def backward(ctx, grad_output):
        """Return ``grad_output.diag`` at the offset stored on ``ctx``, plus ``None``.

        Args:
            ctx: context object carrying ``diagonal_idx``.
            grad_output: object exposing ``diag(offset)``.

        Returns:
            A 2-tuple ``(grad_output.diag(ctx.diagonal_idx), None)``.
        """
        grad_input = grad_output.diag(ctx.diagonal_idx)
        # The second slot is always None.
        return grad_input, None
linalg.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def phi(A):
        """
        Keep the lower triangle of A (diagonal included) and halve its diagonal.

        The input is not modified; a new tensor is returned.
        """
        lower = A.tril()
        # diag(diag(.)) is a matrix holding only the diagonal entries of `lower`.
        diagonal_only = torch.diag(torch.diag(lower))
        return lower - 0.5 * diagonal_only
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_eig(self):
        """Exercise ``torch.eig`` with and without ``out=`` and check ``V diag(e) V^T``.

        Covers a fixed literal matrix (plain call vs. ``out=`` call), repeated
        writes into the same ``out=`` tensors, and non-contiguous ``out=`` views.
        The reconstruction checks use a matrix of the form ``X^T X``.
        """
        # 5 x 5 literal, transposed and made contiguous.
        a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
                          (-6.49, 3.80, 0.00, 0.00, 0.00),
                          (-0.47, -6.39, 4.17, 0.00, 0.00),
                          (-7.20, 1.50, -1.51, 5.70, 0.00),
                          (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
        # First element of the result only; the second positional argument
        # (True) presumably requests eigenvectors -- see the torch.eig docs.
        e = torch.eig(a)[0]
        ee, vv = torch.eig(a, True)
        te = torch.Tensor()
        tv = torch.Tensor()
        eee, vvv = torch.eig(a, True, out=(te, tv))
        # All three call styles must agree with each other.
        self.assertEqual(e, ee, 1e-12)
        self.assertEqual(ee, eee, 1e-12)
        self.assertEqual(ee, te, 1e-12)
        self.assertEqual(vv, vvv, 1e-12)
        self.assertEqual(vv, tv, 1e-12)

        # test reuse
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        e, v = torch.zeros(4, 2), torch.zeros(4, 4)
        torch.eig(X, True, out=(e, v))
        # Only column 0 of e is used for the reconstruction (presumably the
        # real parts of the eigenvalues -- confirm against the torch.eig docs).
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        # The test expects v to be non-contiguous after eig has written to it.
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        # Second write into the same out= tensors.
        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        # test non-contiguous
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        # Strided views taken from larger zero tensors serve as out= targets.
        e = torch.zeros(4, 2, 2)[:, 1]
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_symeig(self):
        """Exercise ``torch.symeig`` with ``out=`` tensors and check ``V diag(e) V^T``.

        The decomposition is written twice into the same ``out=`` tensors, then
        once into non-contiguous views; each time the product
        ``V * diag(e) * V^T`` must reproduce the input matrix.
        """
        xval = torch.rand(100, 3)
        # 3 x 3 matrix of the form X^T X.
        cov = torch.mm(xval.t(), xval)
        rese = torch.zeros(3)
        resv = torch.zeros(3, 3)

        # First call to symeig
        self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
        # A clone is passed so that `cov` itself is never handed to symeig.
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # Second call to symeig
        # After the first call the test expects resv to be non-contiguous.
        self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # test non-contiguous
        X = torch.rand(5, 5)
        # Element-wise product of X^T and X (the `*` operator, not torch.mm).
        X = X.t() * X
        # NOTE(review): the out= views are created with 4 rows while X is
        # 5 x 5; presumably symeig resizes them -- confirm.
        e = torch.zeros(4, 2).select(1, 1)
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.symeig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
test_torch.py 文件源码 项目:pytorch 作者: ezyang 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_svd(self):
        """Exercise ``torch.svd`` with and without ``out=`` and check the factors.

        Three scenarios are covered: a fixed literal matrix (plain call vs.
        ``out=`` call), re-using the tensors returned by a previous call as
        ``out=`` targets, and writing into non-contiguous ``out=`` views. The
        last two verify that ``U * diag(S) * V^T`` reproduces the input.
        """
        # Five rows of six literals, transposed and cloned: ``a`` is 6 x 5.
        a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                          (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                          (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                          (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                          (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        # Empty tensors used as out= destinations.
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
        # Both the out= tensors and the returned tensors must match the plain
        # call. The third argument (0) is presumably the allowed tolerance of
        # this test case's assertEqual -- confirm against the test base class.
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # The test expects the U returned above to be non-contiguous, then
        # feeds the same three tensors back in as out= targets.
        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        # Index 1 along dim 1 of a larger zero tensor gives strided views;
        # the assertions below check that they really are non-contiguous.
        U = torch.zeros(5, 2, 5)[:, 1]
        S = torch.zeros(5, 2)[:, 1]
        V = torch.zeros(5, 2, 5)[:, 1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
kronecker_product_lazy_variable_test.py 文件源码 项目:gpytorch 作者: jrg365 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def pending_test_diag():
    """Compare the lazy Kronecker-product diagonal with ``torch.diag`` of ``WKW``.

    Relies on the module-level names ``WKW``, ``lazy_kronecker_product_var``
    and ``utils``.
    """
    # Reference diagonal taken from the module-level WKW value.
    expected = torch.diag(WKW)
    # Diagonal reported by the lazy variable; its .data is what gets compared.
    computed = lazy_kronecker_product_var.diag()
    assert utils.approx_equal(computed.data, expected)
mul_lazy_variable.py 文件源码 项目:gpytorch 作者: jrg365 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def add_diag(self, diag):
        """Return a new ``MulLazyVariable`` whose added diagonal includes ``diag``.

        When no diagonal has been added yet, ``diag`` is expanded to
        ``self.size()[0]``; otherwise it is summed with the existing
        ``self.added_diag``. All other constructor settings are carried over
        from ``self``; ``self`` is not modified.
        """
        if self.added_diag is None:
            # First diagonal: expand it to the leading size of this variable.
            combined_diag = diag.expand(self.size()[0])
        else:
            # Accumulate on top of the diagonal that is already stored.
            combined_diag = self.added_diag + diag

        return MulLazyVariable(
            *self.lazy_vars,
            matmul_mode=self.matmul_mode,
            max_iter=self.max_iter,
            num_samples=self.num_samples,
            added_diag=combined_diag
        )
mul_lazy_variable.py 文件源码 项目:gpytorch 作者: jrg365 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def diag(self):
        """Return the element-wise product of every factor's diagonal, plus ``added_diag`` if set.

        The running product starts from a ``Variable`` of ones of length
        ``self.size()[0]``.
        """
        length = self.size()[0]
        product = Variable(torch.ones(length))
        for factor in self.lazy_vars:
            product = product * factor.diag()

        # Without an extra diagonal the product is already the answer.
        if self.added_diag is None:
            return product
        return product + self.added_diag
mul_lazy_variable.py 文件源码 项目:gpytorch 作者: jrg365 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def evaluate(self):
        """Multiply the evaluated factors together with ``*`` and add ``added_diag.diag()`` if set.

        Returns ``None`` when there are no factors and no added diagonal.
        """
        product = None
        for factor in self.lazy_vars:
            dense = factor.evaluate()
            # The first factor seeds the product; later ones are multiplied in.
            product = dense if product is None else product * dense

        if self.added_diag is None:
            return product
        return product + self.added_diag.diag()
single.py 文件源码 项目:qpth 作者: locuslab 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def factor_kkt(U_S, R, d):
    """ Factor the U22 block that we can only do after we know D.

    Computes the upper Cholesky factor of ``R + diag(1 / d)`` and writes it
    in place into the trailing ``nineq`` x ``nineq`` corner of ``U_S``, where
    ``nineq = R.size(0)``.

    Args:
        U_S: 2-D tensor that is modified in place.
        R: square tensor of size ``nineq``.
        d: 1-D tensor of length ``nineq``; its reciprocals form the diagonal.

    Returns:
        None.
    """
    nineq = R.size(0)
    # The reciprocal diagonal is built on the CPU and cast back to d's type.
    shifted = R + torch.diag(1 / d.cpu()).type_as(d)
    U_S[-nineq:, -nineq:] = _upper_cholesky(shifted)


def _upper_cholesky(mat):
    """Return the upper-triangular factor ``U`` with ``U^T U = mat``.

    ``torch.potrf`` is no longer available in current PyTorch, so the newest
    available routine is used and older ones are kept as fallbacks.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'cholesky'):
        # linalg.cholesky returns the lower factor; transpose it to get U.
        return linalg.cholesky(mat).transpose(-2, -1)
    if hasattr(torch, 'cholesky'):
        return torch.cholesky(mat, upper=True)
    return torch.potrf(mat)
test_autograd.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def random_square_matrix_of_rank(l, rank):
    """Return a random ``l`` x ``l`` matrix with ``rank`` non-zero singular values.

    A random matrix is decomposed with SVD; singular values at index ``rank``
    and above are zeroed, any exactly-zero value below ``rank`` is replaced by
    1, and the matrix is rebuilt from the edited factors.
    """
    assert rank <= l
    base = torch.randn(l, l)
    left, sigma, right = base.svd()
    for idx in range(l):
        if idx < rank:
            # Kept singular values must not be zero.
            if sigma[idx] == 0:
                sigma[idx] = 1
        else:
            # Everything from index `rank` onward is dropped.
            sigma[idx] = 0
    return torch.mm(torch.mm(left, torch.diag(sigma)), right.t())
test_autograd.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def random_fullrank_matrix_distinct_singular_value(l):
    """Return a random ``l`` x ``l`` matrix with singular values ``k / (l + 1)``, ``k = 1..l``.

    The singular vectors come from the SVD of a random matrix; the singular
    values are replaced by the evenly spaced, pairwise distinct, non-zero
    values above, so the result has full rank.

    Args:
        l: size of the square matrix.

    Returns:
        An ``l`` x ``l`` tensor of the same type as ``torch.randn(l, l)``.
    """
    A = torch.randn(l, l)
    u, _, v = A.svd()
    # Cast to A's type before scaling: with integer bounds torch.arange yields
    # an integer tensor on current PyTorch, and an in-place multiply by a
    # float on that tensor raises instead of producing fractions.
    s = torch.arange(1, l + 1).type_as(A).mul_(1.0 / (l + 1))
    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def test_diag(self):
        """Check that ``torch.diag`` gives the same result when returned and when written via ``out=``."""
        x = torch.rand(100, 100)
        # Result returned by the plain call.
        res1 = torch.diag(x)
        # Empty tensor that the out= call is expected to fill.
        res2 = torch.Tensor()
        torch.diag(x, out=res2)
        self.assertEqual(res1, res2)
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_symeig(self):
        """Exercise ``torch.symeig`` with ``out=`` tensors and check ``V diag(e) V^T``.

        The decomposition is written twice into the same ``out=`` tensors, then
        once into non-contiguous views; each time the product
        ``V * diag(e) * V^T`` must reproduce the input matrix.
        """
        xval = torch.rand(100, 3)
        # 3 x 3 matrix of the form X^T X.
        cov = torch.mm(xval.t(), xval)
        rese = torch.zeros(3)
        resv = torch.zeros(3, 3)

        # First call to symeig
        self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
        # A clone is passed so that `cov` itself is never handed to symeig.
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # Second call to symeig
        # After the first call the test expects resv to be non-contiguous.
        self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # test non-contiguous
        X = torch.rand(5, 5)
        # Element-wise product of X^T and X (the `*` operator, not torch.mm).
        X = X.t() * X
        # NOTE(review): the out= views are created with 4 rows while X is
        # 5 x 5; presumably symeig resizes them -- confirm.
        e = torch.zeros(4, 2).select(1, 1)
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.symeig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
test_torch.py 文件源码 项目:pytorch 作者: pytorch 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_svd(self):
        """Exercise ``torch.svd`` with and without ``out=`` and check the factors.

        Three scenarios are covered: a fixed literal matrix (plain call vs.
        ``out=`` call), re-using the tensors returned by a previous call as
        ``out=`` targets, and writing into non-contiguous ``out=`` views. The
        last two verify that ``U * diag(S) * V^T`` reproduces the input.
        """
        # Five rows of six literals, transposed and cloned: ``a`` is 6 x 5.
        a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                          (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                          (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                          (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                          (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        # Empty tensors used as out= destinations.
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
        # Both the out= tensors and the returned tensors must match the plain
        # call. The third argument (0) is presumably the allowed tolerance of
        # this test case's assertEqual -- confirm against the test base class.
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # The test expects the U returned above to be non-contiguous, then
        # feeds the same three tensors back in as out= targets.
        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        # Index 1 along dim 1 of a larger zero tensor gives strided views;
        # the assertions below check that they really are non-contiguous.
        U = torch.zeros(5, 2, 5)[:, 1]
        S = torch.zeros(5, 2)[:, 1]
        V = torch.zeros(5, 2, 5)[:, 1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
dgmm.py 文件源码 项目:pyinn 作者: szagoruyko 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def cublas_dgmm(A, x, out=None):
    """Multiply ``A`` by the diagonal matrix formed from the vector ``x``.

    For tensors that are not ``torch.cuda.FloatTensor`` or
    ``torch.cuda.DoubleTensor`` the product is computed with ``torch.diag``
    and ``mm``. For those two CUDA types the ``cublasSdgmm`` / ``cublasDdgmm``
    routine from ``skcuda`` is called on the raw data pointers.

    Args:
        A: contiguous tensor; must have the same type as ``x`` and ``out``.
        x: 1-D tensor whose length equals ``A.size(-1)`` or ``A.size(0)``.
        out: optional contiguous tensor of the same size as ``A``; allocated
            with ``A.new`` when omitted.

    Returns:
        The result of the ``mm`` call on the first path, ``out`` on the
        cuBLAS path.

    Raises:
        AssertionError: if any of the shape / type / contiguity checks fail.
    """
    if out is not None:
        assert out.is_contiguous() and out.size() == A.size()
    else:
        out = A.new(A.size())
    assert x.dim() == 1
    assert x.numel() == A.size(-1) or x.numel() == A.size(0)
    assert A.type() == x.type() == out.type()
    assert A.is_contiguous()

    if not isinstance(A, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor)):
        # Fallback path. When x matches the last dimension it multiplies from
        # the right (A @ diag(x)), otherwise from the left (diag(x) @ A). If
        # both dimensions match, the first branch wins.
        if x.numel() == A.size(-1):
            return A.mm(torch.diag(x), out=out.view_as(A))
        else:
            return torch.diag(x).mm(A, out=out.view_as(A))
    else:
        # cuBLAS path: pick the mode flag and the (m, n) extents passed to dgmm.
        # NOTE(review): 'l' / 'r' are presumably the cuBLAS side modes and
        # m / n the column-major dimensions -- confirm against the cuBLAS docs.
        if x.numel() == A.size(-1):
            m, n =  A.size(-1), A.numel() // A.size(-1)
            mode = 'l'
            # A.mm(x.diag(), out=out)
            # return out
        elif x.numel() == A.size(0):
            n, m = A.size(0), A.numel() // A.size(0)
            mode = 'r'
            # if A.stride(0) == 1:
            #     mode = 'l'
            #     n, m = m, n
            # x.diag().mm(A, out=out)
            # return out
        # Both leading dimensions are set to m; x is read with unit stride.
        lda, ldc = m, m
        incx = 1
        # Point cuBLAS at torch's current handle and stream before the call.
        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        from skcuda import cublas
        cublas.cublasSetStream(handle, stream)
        args = [handle, mode, m, n, A.data_ptr(), lda, x.data_ptr(), incx, out.data_ptr(), ldc]
        # Single- vs. double-precision routine, chosen by tensor type.
        if isinstance(A, torch.cuda.FloatTensor):
            cublas.cublasSdgmm(*args)
        elif isinstance(A, torch.cuda.DoubleTensor):
            cublas.cublasDdgmm(*args)
        return out


问题


面经


文章

微信
公众号

扫码关注公众号