def _sample(self, n_samples):
    """Draw ``n_samples`` Poisson samples for every element of the batch.

    Uses ``tf.random_poisson`` when the installed TensorFlow provides it;
    otherwise falls back to Knuth's multiplication-of-uniforms algorithm
    executed inside a ``tf.while_loop``.

    :param n_samples: Number of samples to draw. A Python ``int`` also fixes
        the static size of the leading dimension; otherwise (presumably a
        scalar integer Tensor) that dimension is left unknown statically.
    :return: A Tensor with leading dimension ``n_samples`` followed by the
        batch shape (the shape of ``self.rate`` in the ``tf.random_poisson``
        path), cast to ``self.dtype``.
    """
    try:
        # tf.random_poisson is only available in TensorFlow versions after
        # v1.2; keep the try body to this single attribute lookup.
        random_poisson = tf.random_poisson
    except AttributeError:
        # Fallback: Knuth's algorithm [1]. Each element keeps a running
        # product of Uniform(0, 1) draws; its sample is the number of
        # multiplications after which the product is still >= exp(-rate).
        # The loop ends once every element's product has fallen below
        # exp(-rate). Products only decrease, so an element that has dropped
        # below the threshold never increments again.
        # NOTE: the expected number of iterations grows linearly with the
        # rate, and exp(-rate) leaves the normal float range for large rates
        # (roughly > 87 in float32, > 708 in float64). This fallback is
        # therefore only reliable for moderate rates.
        # [1]: https://en.wikipedia.org/wiki/
        # Poisson_distribution#Generating_Poisson-distributed_random_variables
        shape = tf.concat([[n_samples], self.batch_shape], 0)
        static_n_samples = n_samples if isinstance(n_samples, int) else None
        static_shape = tf.TensorShape([static_n_samples]).concatenate(
            self.get_batch_shape())
        enlam = tf.exp(-self.rate)
        x = tf.zeros(shape, dtype=self.dtype)
        prod = tf.ones(shape, dtype=self.param_dtype)

        def loop_cond(prod, x):
            # Continue while any element is still above the threshold.
            return tf.reduce_any(tf.greater_equal(prod, enlam))

        def loop_body(prod, x):
            # Draw the uniforms in the same dtype as ``prod``. The default
            # (float32) would fail to multiply with a float64 ``prod``.
            prod *= tf.random_uniform(tf.shape(prod), minval=0, maxval=1,
                                      dtype=self.param_dtype)
            x += tf.cast(tf.greater_equal(prod, enlam), dtype=self.dtype)
            return prod, x

        # The leading dimension may be unknown statically (non-int
        # n_samples), so the shape invariants describe exactly that shape.
        _, samples = tf.while_loop(
            loop_cond, loop_body, loop_vars=[prod, x],
            shape_invariants=[static_shape, static_shape])
        samples.set_shape(static_shape)
    else:
        samples = random_poisson(self.rate, [n_samples],
                                 dtype=self.param_dtype)
        # Samples are produced in the parameter dtype; convert them to the
        # distribution's sample dtype when the two differ.
        if self.param_dtype != self.dtype:
            samples = tf.cast(samples, self.dtype)
    return samples
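# NOTE(review): stray web-page navigation labels captured with this code
# (translated: "Comment list", "Table of contents"); not part of the program.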