statistics.py source file

python

Project: a-nice-mc  Author: ermongroup
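# Excerpt of a class constructor, written against the TensorFlow 1.x graph API.
# metropolis_hastings_accept is defined elsewhere in the project.
import tensorflow as tf


def __init__(self, energy_fn, prior, std=1.0,
             inter_op_parallelism_threads=1, intra_op_parallelism_threads=1):
    self.energy_fn = energy_fn
    self.prior = prior
    # State tensor exposed by the energy function; used as the chain's start.
    self.z = self.energy_fn.z

    def fn(z, x):
        # One random-walk step: perturb the state with Gaussian noise of scale std.
        z_ = z + tf.random_normal(tf.shape(self.z), 0.0, std)
        accept = metropolis_hastings_accept(
            energy_prev=energy_fn(z),
            energy_next=energy_fn(z_)
        )
        # Keep the proposal where it was accepted, otherwise the previous state.
        return tf.where(accept, z_, z)

    # Number of sampling steps, fed at run time.
    self.steps = tf.placeholder(tf.int32, [])
    # Dummy sequence: only its length matters, since fn ignores x.
    elems = tf.zeros([self.steps])
    # tf.scan threads z through fn and stacks the state after every step.
    self.z_ = tf.scan(
        fn, elems, self.z, back_prop=False
    )

    self.sess = tf.Session(
        config=tf.ConfigProto(
            inter_op_parallelism_threads=inter_op_parallelism_threads,
            intra_op_parallelism_threads=intra_op_parallelism_threads
        )
    )
    self.sess.run(tf.global_variables_initializer())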
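
The constructor above builds a random-walk Metropolis-Hastings chain as a TensorFlow graph. For readers without a TensorFlow 1.x environment, the same loop can be sketched in plain NumPy. This is a minimal illustration, not the project's code: `random_walk_mh` and `gaussian_energy` are hypothetical names, and the acceptance rule shown here (accept with probability min(1, exp(energy_prev - energy_next))) is the standard Metropolis rule for a symmetric proposal, assumed rather than copied from the project's `metropolis_hastings_accept`.

```python
import numpy as np


def metropolis_hastings_accept(energy_prev, energy_next, rng):
    """Assumed rule: accept with probability min(1, exp(energy_prev - energy_next))."""
    # log(1 - u) with u in [0, 1) is the log of a uniform draw on (0, 1].
    log_u = np.log1p(-rng.uniform(size=energy_prev.shape))
    return log_u < energy_prev - energy_next


def random_walk_mh(energy_fn, z0, steps, std=1.0, seed=0):
    """Run `steps` random-walk updates on a batch of chains.

    z0 has shape [batch, dim]; the result has shape [steps, batch, dim],
    mirroring the stacked per-step states that tf.scan returns.
    """
    rng = np.random.default_rng(seed)
    z = np.asarray(z0, dtype=np.float64)
    trace = np.empty((steps,) + z.shape)
    for t in range(steps):
        # Propose a Gaussian perturbation of the current state.
        z_ = z + rng.normal(0.0, std, size=z.shape)
        accept = metropolis_hastings_accept(energy_fn(z), energy_fn(z_), rng)
        # Per chain: take the proposal if accepted, else keep the old state.
        z = np.where(accept[:, None], z_, z)
        trace[t] = z
    return trace


def gaussian_energy(z):
    """Hypothetical target: standard normal, energy = 0.5 * ||z||^2 per chain."""
    return 0.5 * np.sum(z ** 2, axis=1)


samples = random_walk_mh(gaussian_energy, np.zeros((64, 2)), steps=500)
```

Here `samples[t]` holds the state of all 64 chains after step `t + 1`, so discarding the first rows acts as burn-in. The explicit Python loop plays the role of `tf.scan` over the dummy `elems` tensor.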