univariate.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def _sample(self, n_samples):
        """Draw `n_samples` Poisson samples with rate `self.rate`.

        Uses `tf.random_poisson` when the installed TensorFlow provides it;
        otherwise falls back to Knuth's multiplication algorithm implemented
        with `tf.while_loop`.

        :param n_samples: Number of samples to draw; a Python int or a
            scalar Tensor. Only a Python int contributes a static leading
            dimension in the fallback path.
        :return: A Tensor of dtype `self.dtype`. In the fallback path its
            shape is `[n_samples] + batch_shape`.
        """
        try:
            # tf.random_poisson is implemented after v1.2
            random_poisson = tf.random_poisson
        except AttributeError:
            # This algorithm to generate random Poisson-distributed numbers is
            # given by Knuth [1]
            # [1]: https://en.wikipedia.org/wiki/
            #      Poisson_distribution#Generating_Poisson-distributed_random_variables
            shape = tf.concat([[n_samples], self.batch_shape], 0)
            static_n_samples = n_samples if isinstance(n_samples,
                                                       int) else None
            static_shape = tf.TensorShape([static_n_samples]).concatenate(
                self.get_batch_shape())
            # exp(-rate) is the threshold the running product is compared to.
            enlam = tf.exp(-self.rate)
            x = tf.zeros(shape, dtype=self.dtype)
            prod = tf.ones(shape, dtype=self.param_dtype)

            def loop_cond(prod, x):
                # Keep iterating while any element is still above threshold.
                return tf.reduce_any(tf.greater_equal(prod, enlam))

            def loop_body(prod, x):
                # The uniform draw must share `prod`'s dtype: the default
                # (float32) would make the multiplication fail for any other
                # `param_dtype`, e.g. float64.
                prod *= tf.random_uniform(tf.shape(prod), minval=0, maxval=1,
                                          dtype=self.param_dtype)
                # Only elements whose product is still >= exp(-rate) count
                # this step; `prod` never increases, so finished elements
                # stop accumulating.
                x += tf.cast(tf.greater_equal(prod, enlam), dtype=self.dtype)
                return prod, x

            _, samples = tf.while_loop(
                loop_cond, loop_body, loop_vars=[prod, x],
                shape_invariants=[static_shape, static_shape])

            # Restore the static shape information after the loop.
            samples.set_shape(static_shape)
        else:
            samples = random_poisson(self.rate, [n_samples],
                                     dtype=self.param_dtype)
            if self.param_dtype != self.dtype:
                samples = tf.cast(samples, self.dtype)
        return samples
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号