shrinkage.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:onsager_deep_learning 作者: mborgerding 项目源码 文件源码
def shrink_spline(r,rvar,theta):
    """ Spline-based shrinkage function
    """
    scale = theta[0]*tf.sqrt(rvar)
    rs = tf.sign(r)
    ar = tf.abs(r/scale)
    ar2 = tf.square(ar)
    ar3 = ar*ar2
    reg1 = tf.to_float(ar<1)
    reg2 = tf.to_float(ar<2)-reg1
    ar_m2 = 2-ar
    ar_m2_p2 = tf.square(ar_m2)
    ar_m2_p3 = ar_m2*ar_m2_p2
    beta3 = ( (2./3 - ar2  + .5*ar3)*reg1 + (1./6*(ar_m2_p3))*reg2 )
    xhat = r*(theta[1] + theta[2]*beta3)
    return (xhat,auto_gradients(xhat,r))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号