shrinkage.py source code

python

Project: onsager_deep_learning  Author: mborgerding
import tensorflow as tf

def shrink_piecwise_linear(r,rvar,theta):
    """Implement the piecewise linear shrinkage function,
        with minor modifications and variance normalization.
        theta[...,0] : abscissa of first vertex, scaled by sqrt(rvar)
        theta[...,1] : abscissa of second vertex, scaled by sqrt(rvar)
        theta[...,2] : slope from origin to first vertex
        theta[...,3] : slope from first vertex to second vertex
        theta[...,4] : slope after second vertex
        Returns (xhat, dxdr): the shrunk estimate and the slope of the
        active segment (the derivative w.r.t. r), averaged over axis 0.
    """
    ab0 = theta[...,0]
    ab1 = theta[...,1]
    sl0 = theta[...,2]
    sl1 = theta[...,3]
    sl2 = theta[...,4]

    # scale each column by sqrt(rvar)
    scale_out = tf.sqrt(rvar)
    scale_in = 1/scale_out
    rs = tf.sign(r*scale_in)
    ra = tf.abs(r*scale_in)

    # split the piecewise linear function into regions
    # 0/1 indicator masks, one per linear segment (assumes ab0 <= ab1)
    rgn0 = tf.cast( ra<ab0, tf.float32)
    rgn1 = tf.cast( ra<ab1, tf.float32) - rgn0
    rgn2 = tf.cast( ra>=ab1, tf.float32)
    xhat = scale_out * rs*(
            rgn0*sl0*ra +
            rgn1*(sl1*(ra - ab0) + sl0*ab0 ) +
            rgn2*(sl2*(ra - ab1) +  sl0*ab0 + sl1*(ab1-ab0) )
            )
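    # the input and output scalings cancel, so d(xhat)/dr is the slope of the active segment
    dxdr =  sl0*rgn0 + sl1*rgn1 + sl2*rgn2
    # average the derivative over axis 0
    dxdr = tf.reduce_mean(dxdr,0)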
    return (xhat,dxdr)
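
To make the three regions easier to follow, here is a minimal scalar sketch of the same mapping in plain Python, without TensorFlow. The name `shrink_pwl_scalar` is hypothetical and not part of shrinkage.py. It assumes scalar inputs, a `theta` of five numbers in the order documented above, and `0 <= ab0 <= ab1`. It returns the per-sample slope rather than the mean over axis 0.

```python
import math

def shrink_pwl_scalar(r, rvar, theta):
    """Scalar illustration of the piecewise linear shrinkage (hypothetical helper).

    theta = (ab0, ab1, sl0, sl1, sl2); assumes 0 <= ab0 <= ab1 and rvar > 0.
    """
    ab0, ab1, sl0, sl1, sl2 = theta
    scale_out = math.sqrt(rvar)
    ra = abs(r) / scale_out                          # magnitude normalized by sqrt(rvar)
    rs = math.copysign(1.0, r) if r != 0 else 0.0    # sign of r, 0 at r == 0 (as tf.sign)
    if ra < ab0:                                     # segment from the origin to the first vertex
        y, slope = sl0 * ra, sl0
    elif ra < ab1:                                   # segment between the two vertices
        y, slope = sl1 * (ra - ab0) + sl0 * ab0, sl1
    else:                                            # segment past the second vertex
        y, slope = sl2 * (ra - ab1) + sl0 * ab0 + sl1 * (ab1 - ab0), sl2
    return scale_out * rs * y, slope

# With sl0 = 0 and sl1 = sl2 = 1 the function reduces to soft thresholding
# at ab0 * sqrt(rvar), here 1.0 * sqrt(4.0) = 2.0.
theta = (1.0, 2.0, 0.0, 1.0, 1.0)
print(shrink_pwl_scalar(1.0, 4.0, theta))    # (0.0, 0.0)
print(shrink_pwl_scalar(3.0, 4.0, theta))    # (1.0, 1.0)
print(shrink_pwl_scalar(-6.0, 4.0, theta))   # (-4.0, 1.0)
```

Each segment's offset is the value reached at the end of the previous segment, so the function is continuous at both vertices for any choice of slopes.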