def shrink_soft_threshold(r, rvar, theta):
    """Soft-threshold ``r`` and return the average derivative of the shrinkage.

    Computes::

        xhat = sign(r) * max(0, |r| - theta[0] * sqrt(rvar)) * theta[1]

    When ``theta`` is a scalar (rank 0) or has shape ``(1,)``, ``theta`` itself
    is the threshold multiplier and no output scaling is applied (the scale is
    effectively 1), which is the standard soft-threshold function. Otherwise
    ``theta[0]`` is the threshold multiplier and ``theta[1]`` the output
    scale; any further entries of ``theta`` are ignored.

    Args:
        r: Real-valued tensor to shrink.
        rvar: Tensor or scalar whose square root scales the threshold. It must
            broadcast against ``r``. Presumably the noise variance of ``r``;
            confirm against callers.
        theta: Tensor of shrinkage parameters, interpreted as described above.

    Returns:
        A tuple ``(xhat, dxdr)``:
            xhat: the soft-thresholded (and optionally scaled) ``r``.
            dxdr: d(xhat)/d(r) averaged over axis 0. This is the fraction of
                entries along axis 0 with ``|r|`` above the threshold, times
                the output scale when one is used. It has the same dtype as
                ``xhat``.
    """
    theta_shape = theta.get_shape()
    if len(theta_shape) > 0 and theta_shape != (1,):
        lam = theta[0] * tf.sqrt(rvar)
        scale = theta[1]
    else:
        lam = theta * tf.sqrt(rvar)
        scale = None

    # A negative threshold would enlarge |r| instead of shrinking it; clamp to 0.
    lam = tf.maximum(lam, 0)
    arml = tf.abs(r) - lam
    xhat = tf.sign(r) * tf.maximum(arml, 0)
    # The soft threshold has slope 1 where |r| > lam and 0 elsewhere, so the
    # mean of that indicator along axis 0 is the average derivative.
    # tf.cast replaces tf.to_float (removed in TF 2.x). Casting to arml.dtype
    # keeps dxdr in the same dtype as xhat, so multiplying by `scale` below
    # does not hit a float32/float64 mismatch.
    dxdr = tf.reduce_mean(tf.cast(arml > 0, arml.dtype), 0)
    if scale is not None:
        xhat = xhat * scale
        dxdr = dxdr * scale
    return (xhat, dxdr)
评论列表
文章目录