test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:qpth 作者: locuslab 项目源码 文件源码
def test_dl_dA():
    nz, neq, nineq = 10, 3, 1
    [p, Q, G, h, A, b, truez], [dQ, dp, dG, dh, dA, db] = get_grads(
        nz=nz, neq=neq, nineq=nineq, Qscale=100., Gscale=100., Ascale=100.)

    def f(A):
        A = A.reshape(neq, nz)
        _, zhat, nu, lam, slacks = qp_cvxpy.forward_single_np(Q, p, G, h, A, b)
        return 0.5 * np.sum(np.square(zhat - truez))

    df = nd.Gradient(f)
    dA_fd = df(A.ravel()).reshape(neq, nz)
    if verbose:
        # print('dA_fd[0,:]: ', dA_fd[0,:])
        # print('dA[0,:]: ', dA[0,:])
        print('dA_fd: ', dA_fd)
        print('dA: ', dA)
    npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号