optnet-forward.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:optnet 作者: locuslab 项目源码 文件源码
def prof_instance(nz, neq, nineq, nBatch, cuda):
    """Time single-example vs. batched ``aip`` forward solves on one random QP.

    A random problem is built with ``nz`` variables, ``neq`` rows in ``A``
    and ``nineq`` rows in ``G``. ``b = A z0`` and ``h = G z0 + s0`` (with
    ``s0 = 1``) are derived from a random point ``z0``; presumably this makes
    ``z0`` strictly feasible for ``G z <= h, A z = b`` (confirm the
    constraint convention used by ``aip``). ``nBatch`` random linear terms
    ``p`` share the same ``Q, G, A, b, h``.

    The problem is solved twice:
      * per example: ``aip.pre_factor_kkt`` once, then
        ``aip.forward_single`` for every row of ``p``;
      * batched: ``aip.pre_factor_kkt_batch`` + ``aip.forward_batch``.
    Each timed region includes its own factorization. If, for any example,
    the first output (``zhat``) or third output (``lam``) of the two paths
    differs by more than a relaxed L2 tolerance, a warning is printed.

    Args:
        nz: number of optimization variables.
        neq: number of rows of ``A``; when 0, ``b`` is passed as ``None``.
        nineq: number of rows of ``G``.
        nBatch: number of examples (rows of ``p``); must be at least 1.
        cuda: if true, the problem tensors are moved to the GPU.

    Returns:
        ``(single_time, batched_time)``: wall-clock seconds of the two paths.

    Raises:
        ValueError: if ``nBatch < 1`` (there would be nothing to compare).
    """
    if nBatch < 1:
        raise ValueError("nBatch must be at least 1, got {}".format(nBatch))

    # Q = L L^T + 0.001 I is positive definite: L is lower-triangular with
    # diagonal entries in [1, 2), hence nonsingular.
    L = np.tril(npr.uniform(0, 1, (nz, nz))) + np.eye(nz, nz)
    G = npr.randn(nineq, nz)
    A = npr.randn(neq, nz)
    z0 = npr.randn(nz)
    s0 = np.ones(nineq)
    p = npr.randn(nBatch, nz)

    p, L, G, A, z0, s0 = [torch.Tensor(x) for x in [p, L, G, A, z0, s0]]
    Q = torch.mm(L, L.t()) + 0.001 * torch.eye(nz).type_as(L)
    if cuda:
        # L is only needed to build Q, so it is not copied to the GPU.
        p, Q, G, A, z0, s0 = [x.cuda() for x in [p, Q, G, A, z0, s0]]
    # Right-hand sides built from z0 so that A z0 = b and G z0 + s0 = h.
    b = torch.mv(A, z0) if neq > 0 else None
    h = torch.mv(G, z0) + s0

    def _sync():
        # CUDA work is launched asynchronously; wait for it to finish so the
        # timers measure the actual computation, not just kernel launches.
        if cuda:
            torch.cuda.synchronize()

    # Path 1: one shared factorization, then one solve per example.
    _sync()
    start = time.time()
    U_Q, U_S, R = aip.pre_factor_kkt(Q, G, A)
    single_results = [
        aip.forward_single(p[i], Q, G, A, b, h, U_Q, U_S, R)
        for i in range(nBatch)
    ]
    _sync()
    single_time = time.time() - start

    # Path 2: batched factorization and solve.
    _sync()
    start = time.time()
    Q_LU, S_LU, R = aip.pre_factor_kkt_batch(Q, G, A, nBatch)
    zhat_b, _, lam_b = aip.forward_batch(p, Q, G, A, b, h, Q_LU, S_LU, R)
    _sync()
    batched_time = time.time() - start

    # Compare every example (worst case), not only the first one.
    # float() works whether .norm() yields a Python number or a 0-dim tensor.
    zhat_diff = max(float((res[0] - zhat_b[i]).norm())
                    for i, res in enumerate(single_results))
    lam_diff = max(float((res[2] - lam_b[i]).norm())
                   for i, res in enumerate(single_results))
    eps = 0.1  # Pretty relaxed.
    if zhat_diff > eps or lam_diff > eps:
        print('===========')
        print("Warning: Single and batched solutions might not match.")
        print("  + zhat_diff: {}".format(zhat_diff))
        print("  + lam_diff: {}".format(lam_diff))
        print("  + (nz, neq, nineq, nBatch) = ({}, {}, {}, {})".format(
            nz, neq, nineq, nBatch))
        print('===========')

    return single_time, batched_time
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号