solveCrossTime.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:TICC 作者: davidhallac 项目源码 文件源码
def ADMM_x(entry):
    """Run the ADMM x-update (proximal step) for a single node.

    ``entry`` is a per-node record indexed by the module-level constants
    ``X_VARS``, ``X_DEG``, ``X_NEIGHBORS`` and ``X_IND``; ``entry[1]`` is the
    node's objective expression, whose ``.args`` are inspected.  A node whose
    objective has at most one argument is treated as a dummy node and left
    untouched.

    For a "logdet + trace" node (objective with more than one argument):

    1. read the data matrix stored at ``entry[1].args[1].args[0].args[0]``;
    2. accumulate ``z - u`` over every neighbouring edge and every node
       variable into a vector of length ``n * (n + 1) // 2``, where ``n`` is
       the matrix's column count (presumably the packed upper triangle,
       given the later ``upper2Full`` call);
    3. expand the sum with ``upper2Full`` and divide it by the node degree;
    4. apply ``Prox_logdet`` with step ``eta = 1 / rho``;
    5. write the flattened, transposed result into ``node_vals`` at the slot
       of the node's first variable via ``writeValue``.

    Args:
        entry: per-node record as described above.

    Returns:
        None.  The update is communicated only through ``writeValue``.
    """
    variables = entry[X_VARS]

    # A single-argument objective marks a dummy node: nothing to update.
    # NOTE(review): __builtin__.len is presumably used because ``len`` is
    # shadowed by a wildcard import elsewhere in the module -- confirm.
    if __builtin__.len(entry[1].args) <= 1:
        return None

    # Data matrix of the "logdet + trace" objective.
    cvxpyMat = entry[1].args[1].args[0].args[0]
    numpymat = cvxpyMat.value

    # n * (n + 1) is always even, so floor division is exact.
    n = numpymat.shape[1]
    a = numpy.zeros((n * (n + 1) // 2,))

    for i in xrange(entry[X_DEG]):
        # Neighbour i's z and u indices are stored back to back in entry.
        z_index = X_NEIGHBORS + 2 * i
        zi = entry[z_index]
        ui = entry[z_index + 1]
        for (_, _, var, offset) in variables:
            z = getValue(edge_z_vals, zi + offset, var.size[0])
            u = getValue(edge_u_vals, ui + offset, var.size[0])
            a += z - u

    # Average the accumulated (z - u) over the node's neighbours.
    # NOTE(review): X_DEG == 0 would divide by zero here; presumably
    # logdet nodes always have at least one neighbour -- confirm.
    A = upper2Full(a) / entry[X_DEG]
    eta = 1 / float(rho)

    x_update = Prox_logdet(numpymat, A, eta)
    # NOTE(review): the transpose before flattening presumably matches the
    # layout expected in node_vals -- confirm.
    solution = numpy.array(x_update).T.reshape(-1)

    # Only the first variable's slot is written.
    writeValue(node_vals, entry[X_IND] + variables[0][3], solution,
               variables[0][2].size[0])
    return None

# z-update for ADMM for one edge
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号