optnet-np.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:optnet 作者: locuslab 项目源码 文件源码
def test_ip_forward():
    p_t, Q_t, G_t, A_t, z0_t, s0_t = [torch.Tensor(x) for x in [p, Q, G, A, z0, s0]]
    b = torch.mv(A_t, z0_t) if neq > 0 else None
    h = torch.mv(G_t,z0_t)+s0_t
    L_Q, L_S, R = aip.pre_factor_kkt(Q_t, G_t, A_t)

    zhat_ip, nu_ip, lam_ip = aip.forward_single(p_t, Q_t, G_t, A_t, b, h, L_Q, L_S, R)
    # Unnecessary clones here because of a pytorch bug when calling numpy
    # on a tensor with a non-zero offset.
    npt.assert_allclose(zhat, zhat_ip.clone().numpy(), rtol=RTOL, atol=ATOL)
    if neq > 0:
        npt.assert_allclose(nu, nu_ip.clone().numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(lam, lam_ip.clone().numpy(), rtol=RTOL, atol=ATOL)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号