util.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:MGP-RNN 作者: jfutoma 项目源码 文件源码
def CG(A,b):
    """Approximately solve ``A x = b`` for ``x`` by conjugate gradient.

    Iterative alternative to a Cholesky solve that can be faster for
    large-scale problems.

    Args:
        A: 2-D tensor; its first dimension defines the problem size ``n``.
            Conjugate gradient assumes A is symmetric positive definite;
            that is not checked here.
        b: right-hand side. It is flattened to a vector, and its dtype is
            used for the iterate and the stopping tolerance.

    Returns:
        The approximate solution reshaped to a column tensor of shape
        ``[n, 1]``.

    The loop stops once the residual norm is no larger than ``n / 1000``
    or after ``n // 250 + 3`` iterations, whichever happens first.
    """
    b = tf.reshape(b, [-1])
    n = tf.shape(A)[0]
    dtype = b.dtype

    # Starting from x = 0 makes the initial residual, and therefore the
    # first search direction, equal to b. x is created in b's dtype so the
    # loop variables keep one consistent type (tf.zeros defaults to
    # float32, which breaks the loop for float64 inputs).
    x = tf.zeros([n], dtype=dtype)
    r_ = b
    p = r_

    # These settings are somewhat arbitrary; you might want to test
    # sensitivity to them.
    # Cast n before dividing so the tolerance is a true n/1000 regardless
    # of the interpreter's integer-division semantics.
    CG_EPS = tf.cast(n, dtype) / 1000.0
    # Floor division of the (non-negative) size; works on TF 1.x and 2.x.
    MAX_ITER = n // 250 + 3

    def cond(i, x, r, p):
        # Keep iterating while under the cap and not yet converged.
        return tf.logical_and(i < MAX_ITER, tf.norm(r) > CG_EPS)

    def body(i, x, r_, p):
        # A @ p, flattened back to a vector.
        Ap = tf.reshape(tf.matmul(A, tf.reshape(p, [-1, 1])), [-1])

        # <r_, r_> feeds both the step size and the direction update, so
        # compute it once per iteration.
        rr_old = dot(r_, r_)

        alpha = rr_old / dot(p, Ap)      # step length along p
        x = x + alpha * p                # updated solution estimate
        r = r_ - alpha * Ap              # updated residual
        beta = dot(r, r) / rr_old
        p = r + beta * p                 # next search direction

        return i + 1, x, r, p

    i = tf.constant(0)
    _, x, _, _ = tf.while_loop(cond, body, loop_vars=[i, x, r_, p])

    return tf.reshape(x, [-1, 1])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号