test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:qpth 作者: locuslab 项目源码 文件源码
def get_grads(nBatch=1, nz=10, neq=1, nineq=3, Qscale=1.,
              Gscale=1., hscale=1., Ascale=1., bscale=1.):
    assert(nBatch == 1)
    npr.seed(1)
    L = np.random.randn(nz, nz)
    Q = Qscale * L.dot(L.T)
    G = Gscale * npr.randn(nineq, nz)
    # h = hscale*npr.randn(nineq)
    z0 = npr.randn(nz)
    s0 = npr.rand(nineq)
    h = G.dot(z0) + s0
    A = Ascale * npr.randn(neq, nz)
    # b = bscale*npr.randn(neq)
    b = A.dot(z0)

    p = npr.randn(nBatch, nz)
    # print(np.linalg.norm(p))
    truez = npr.randn(nBatch, nz)

    Q, p, G, h, A, b, truez = [x.astype(np.float64) for x in
                               [Q, p, G, h, A, b, truez]]
    _, zhat, nu, lam, slacks = qp_cvxpy.forward_single_np(Q, p[0], G, h, A, b)

    grads = get_grads_torch(Q, p, G, h, A, b, truez)
    return [p[0], Q, G, h, A, b, truez], grads
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号