hmc.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def update(self, x):
        # x: (chain_dims data_dims)
        new_t = tf.assign(self.t, self.t + 1)
        weight = (1 - self.decay) / (1 - tf.pow(self.decay, new_t))
        # incr: (chain_dims data_dims)
        incr = [weight * (q - mean) for q, mean in zip(x, self.mean)]
        # mean: (1,...,1 data_dims)
        update_mean = [mean.assign_add(
            tf.reduce_mean(i, axis=self.chain_axes, keep_dims=True))
            for mean, i in zip(self.mean, incr)]
        # var: (1,...,1 data_dims)
        new_var = [
            (1 - weight) * var +
            tf.reduce_mean(i * (q - mean), axis=self.chain_axes,
                           keep_dims=True)
            for var, i, q, mean in zip(self.var, incr, x, update_mean)]

        update_var = [tf.assign(var, n_var)
                      for var, n_var in zip(self.var, new_var)]
        return update_var
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号