shrinkage.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:onsager_deep_learning 作者: mborgerding 项目源码 文件源码
def shrink_soft_threshold(r,rvar,theta):
    """
    soft threshold function
        y=sign(x)*max(0,abs(x)-theta[0]*sqrt(rvar) )*scaling
    where scaling is theta[1] (default=1)
    in other words, if theta is len(1), then the standard
    """
    if len(theta.get_shape())>0 and theta.get_shape() != (1,):
        lam = theta[0] * tf.sqrt(rvar)
        scale=theta[1]
    else:
        lam  = theta * tf.sqrt(rvar)
        scale = None
    lam = tf.maximum(lam,0)
    arml = tf.abs(r) - lam
    xhat = tf.sign(r) * tf.maximum(arml,0)
    dxdr = tf.reduce_mean(tf.to_float(arml>0),0)
    if scale is not None:
        xhat = xhat*scale
        dxdr = dxdr*scale
    return (xhat,dxdr)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号