def show_shrinkage(shrink_func, theta, **kwargs):
    """Plot the output of a shrinkage function over a ramp of inputs.

    Builds an evenly spaced ramp ``r`` from 0 to ``sigmas * sqrt(rvar)``
    (with ``rvar`` fixed at 1e-4), evaluates ``shrink_func`` on it in a
    fresh TensorFlow graph/session, and plots both the identity line
    (``r`` against ``r``, yellow) and the shrinkage output (``r`` against
    ``xhat``, blue) on matplotlib figure 1.

    Args:
        shrink_func: Callable invoked as ``shrink_func(r_, rvar_, theta_)``.
            It must return a pair; only the first element (the estimate) is
            evaluated and plotted, the second is ignored.
        theta: Shrinkage parameters; passed through ``tfcf`` before being
            handed to ``shrink_func``.
        **kwargs: Optional settings:
            seed: Value given to ``tf.set_random_seed`` (default 1).
            N: Number of rows of the input ramp (default 500).
            L: Number of columns of the input ramp (default 4).
            sigmas: Upper end of the ramp, in multiples of ``sqrt(rvar)``
                (default 10).
            title: If present, used as the figure's suptitle.

    Note:
        Resets the default TensorFlow graph and blocks in ``plt.show()``.
    """
    tf.reset_default_graph()
    tf.set_random_seed(kwargs.get('seed', 1))

    N = kwargs.get('N', 500)
    L = kwargs.get('L', 4)
    nsigmas = kwargs.get('sigmas', 10)
    shape = (N, L)
    rvar = 1e-4

    # Ramp of N*L points from 0 to nsigmas standard deviations, laid out (N, L).
    r = np.reshape(np.linspace(0, nsigmas, N * L) * math.sqrt(rvar), shape)

    # NOTE(review): tfcf is defined elsewhere in this file; presumably it wraps
    # a numpy value as a TensorFlow constant -- confirm against its definition.
    r_ = tfcf(r)
    rvar_ = tfcf(np.ones(L) * rvar)

    # The second return value is not needed for the plot.
    xhat_, _ = shrink_func(r_, rvar_, tfcf(theta))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        xhat = sess.run(xhat_)

    # Imported locally, as in the original, so matplotlib is only required
    # when this helper is actually called.
    import matplotlib.pyplot as plt

    r_flat = r.reshape(-1)
    plt.figure(1)
    plt.plot(r_flat, r_flat, 'y')
    plt.plot(r_flat, xhat.reshape(-1), 'b')
    # dict.has_key() was removed in Python 3; the `in` operator works on both.
    if 'title' in kwargs:
        plt.suptitle(kwargs['title'])
    plt.show()
# Comment list
# Article table of contents