def test_ip_forward():
    """Compare ``aip.forward_single`` against the module's reference solution.

    Reads problem data (``p``, ``Q``, ``G``, ``A``, ``z0``, ``s0``), the
    equality-constraint count ``neq``, the expected results (``zhat``, ``nu``,
    ``lam``) and the tolerances ``RTOL``/``ATOL`` from names defined
    elsewhere in this module.

    The right-hand sides are built from ``z0`` and ``s0``:
    ``b = A @ z0`` (only when ``neq > 0``, otherwise ``None``) and
    ``h = G @ z0 + s0``.
    """
    p_torch, Q_torch, G_torch, A_torch, z0_torch, s0_torch = map(
        torch.Tensor, (p, Q, G, A, z0, s0)
    )

    # With no equality constraints there is no b vector to hand to the solver.
    if neq > 0:
        b = torch.mv(A_torch, z0_torch)
    else:
        b = None
    h = torch.mv(G_torch, z0_torch) + s0_torch

    # Precompute (L_Q, L_S, R) once and pass them through to the solve.
    L_Q, L_S, R = aip.pre_factor_kkt(Q_torch, G_torch, A_torch)
    zhat_ip, nu_ip, lam_ip = aip.forward_single(
        p_torch, Q_torch, G_torch, A_torch, b, h, L_Q, L_S, R
    )

    def _check(expected, actual):
        # The .clone() would otherwise be unnecessary: it works around a
        # PyTorch bug when calling .numpy() on a tensor whose storage has a
        # non-zero offset.
        npt.assert_allclose(expected, actual.clone().numpy(), rtol=RTOL, atol=ATOL)

    _check(zhat, zhat_ip)
    # nu only exists when there are equality constraints.
    if neq > 0:
        _check(nu, nu_ip)
    _check(lam, lam_ip)