def _sample_n(self, n, seed=None):
    """Draw `n` samples by thresholding uniform noise against `self.p`.

    A shape is assembled by concatenating `[n]` with `self.batch_shape()`,
    a uniform random tensor of that shape is generated, and each entry is
    compared (strictly less-than) against `self.p`. The boolean outcome is
    then cast to `self.dtype`.

    Args:
        n: Number of samples; becomes the leading entry of the sample shape.
        seed: Optional seed forwarded to `random_ops.random_uniform`.

    Returns:
        The result of `uniform < self.p`, cast to `self.dtype`.
    """
    # NOTE(review): the leading `0` is presumably the concat axis
    # (older axis-first call convention) -- confirm against the library
    # version this file targets.
    sample_shape = array_ops.concat(0, ([n], self.batch_shape()))
    probs = self.p
    # Noise is generated in the same dtype as `self.p` so the comparison
    # below operates on matching types.
    # NOTE(review): relies on random_uniform's default range, presumably
    # [0, 1) -- confirm.
    noise = random_ops.random_uniform(
        sample_shape, seed=seed, dtype=probs.dtype)
    return math_ops.cast(math_ops.less(noise, probs), self.dtype)
# [page residue] Comment list
# [page residue] Table of contents