def shrink_spline(r, rvar, theta):
    """Spline-based shrinkage function.

    Multiplies ``r`` elementwise by a gain that depends on the normalized
    magnitude ``a = |r| / (theta[0] * sqrt(rvar))``::

        xhat = r * (theta[1] + theta[2] * beta3(a))

    where ``beta3`` is the piecewise cubic

        beta3(a) = 2/3 - a**2 + a**3 / 2     for a < 1
                 = (2 - a)**3 / 6            for 1 <= a < 2
                 = 0                         otherwise

    (this has the form of the cubic B-spline kernel).

    Args:
        r: Tensor to be shrunk.
        rvar: Value passed through ``tf.sqrt`` to build the scale;
            presumably a variance estimate for ``r`` -- confirm with callers.
        theta: Indexable holding at least three entries: ``theta[0]`` scales
            the normalization, ``theta[1]`` is the constant part of the gain
            and ``theta[2]`` weights the spline term.

    Returns:
        Tuple ``(xhat, auto_gradients(xhat, r))``. The second element is
        whatever the file's ``auto_gradients`` helper produces for ``xhat``
        with respect to ``r`` (defined elsewhere in this file).
    """
    scale = theta[0] * tf.sqrt(rvar)
    ar = tf.abs(r / scale)
    ar2 = tf.square(ar)
    ar3 = ar * ar2

    # 0/1 masks selecting the two non-zero pieces of the spline.  tf.cast
    # replaces the deprecated tf.to_float; casting to ar.dtype keeps the
    # masks compatible with ar in the products below.
    inner_mask = tf.cast(ar < 1, ar.dtype)
    outer_mask = tf.cast(ar < 2, ar.dtype) - inner_mask

    ar_m2 = 2 - ar
    ar_m2_p3 = ar_m2 * tf.square(ar_m2)

    beta3 = ((2. / 3 - ar2 + .5 * ar3) * inner_mask
             + (1. / 6 * ar_m2_p3) * outer_mask)
    xhat = r * (theta[1] + theta[2] * beta3)
    return (xhat, auto_gradients(xhat, r))
评论列表
文章目录