shrinkage.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:onsager_deep_learning 作者: mborgerding 项目源码 文件源码
def show_shrinkage(shrink_func,theta,**kwargs):
    tf.reset_default_graph()
    tf.set_random_seed(kwargs.get('seed',1) )

    N = kwargs.get('N',500)
    L = kwargs.get('L',4)
    nsigmas = kwargs.get('sigmas',10)
    shape = (N,L)
    rvar = 1e-4
    r = np.reshape( np.linspace(0,nsigmas,N*L)*math.sqrt(rvar),shape)
    r_ = tfcf(r)
    rvar_ = tfcf(np.ones(L)*rvar)

    xhat_,dxdr_ = shrink_func(r_,rvar_ ,tfcf(theta))

    with tf.Session() as sess:
        sess.run( tf.global_variables_initializer() )
        xhat = sess.run(xhat_)
    import matplotlib.pyplot as plt
    plt.figure(1)
    plt.plot(r.reshape(-1),r.reshape(-1),'y')
    plt.plot(r.reshape(-1),xhat.reshape(-1),'b')
    if kwargs.has_key('title'):
        plt.suptitle(kwargs['title'])
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号