test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:block 作者: bamos 项目源码 文件源码
def test_torch():
    """Compare ``block`` on torch inputs against a hand-assembled reference.

    A 4x4 block matrix is built twice from the same random sub-matrices:
    once explicitly with ``torch.cat`` (the reference ``K_``) and once via
    ``block``, where ``0`` is expected to expand to a zero block and ``'I'``
    to an identity block of the shape implied by its row and column.

    The comparison is run for plain tensors and again with a mix of
    ``Variable``-wrapped and plain tensors. Equality is exact, because the
    reference is built from the very same tensor entries.
    """
    import torch
    from torch.autograd import Variable

    # Fixed seed so the random sub-matrices are reproducible between runs.
    torch.manual_seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = torch.randn(nx, nx)
    G = torch.randn(nineq, nx)
    A = torch.randn(neq, nx)
    D = torch.diag(torch.rand(nineq))

    def zeros(rows, cols):
        # Zero filler block, always cast to Q's tensor type so that every
        # piece handed to torch.cat has a matching type.
        return torch.zeros(rows, cols).type_as(Q)

    def eye(n):
        # Identity filler block with Q's tensor type.
        return torch.eye(n).type_as(Q)

    # Reference matrix: each inner cat joins one block-row along dim 1,
    # the outer cat stacks the four block-rows along dim 0.
    K_ = torch.cat((
        torch.cat((Q, zeros(nx, nineq), G.t(), A.t()), 1),
        torch.cat((zeros(nineq, nx), D, eye(nineq), zeros(nineq, neq)), 1),
        torch.cat((G, eye(nineq), zeros(nineq, nineq + neq)), 1),
        torch.cat((A, zeros(neq, nineq + nineq + neq)), 1),
    ))

    # Case 1: all entries are plain tensors.
    K = block((
        (Q,   0, G.t(), A.t()),
        (0,   D,   'I',     0),
        (G, 'I',     0,     0),
        (A,   0,     0,     0)
    ))

    assert (K - K_).norm() == 0.0, \
        "block() result differs from the reference for tensor inputs"

    # Case 2: a mix of Variable-wrapped and plain tensors (G.t() and A are
    # deliberately left unwrapped).
    K = block((
        (Variable(Q),   0, G.t(), Variable(A.t())),
        (0,   Variable(D),   'I',     0),
        (Variable(G), 'I',     0,     0),
        (A,   0,     0,     0)
    ))

    assert (K.data - K_).norm() == 0.0, \
        "block() result differs from the reference for Variable inputs"
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号