def _sample(self, n_samples):
    """Draw ``n_samples`` samples for every batch element.

    Standard uniform noise from ``tf.random_uniform`` on [0, 1) is rescaled
    affinely onto the interval spanned by ``self.minval`` and
    ``self.maxval``.

    Args:
        n_samples: Number of samples to draw. It is placed as the leading
            entry of the dynamic sample shape. Presumably a Python int or a
            scalar integer tensor (TODO confirm against callers).

    Returns:
        A tensor of dtype ``self.dtype`` whose dynamic shape is
        ``[n_samples] + self.batch_shape``. The static leading dimension is
        only known when ``n_samples`` is a Python ``int``. Otherwise it is
        left as ``None``.
    """
    low, high = self.minval, self.maxval
    if not self.is_reparameterized:
        # Without the reparameterization trick, no gradient may flow back
        # into the distribution's bounds through the samples.
        low, high = tf.stop_gradient(low), tf.stop_gradient(high)

    sample_shape = tf.concat([[n_samples], self.batch_shape], 0)
    # NOTE: the graph ops below are created in a fixed order: noise, then
    # width, then product, then sum. TF1 derives op-level seeds from the
    # graph's op counter when a graph seed is set, so reordering could
    # change the sampled values.
    unit_noise = tf.random_uniform(sample_shape, 0, 1, dtype=self.dtype)
    width = high - low
    samples = unit_noise * width + low

    # Give the result as much static shape information as is available.
    if isinstance(n_samples, int):
        leading_dim = n_samples
    else:
        leading_dim = None
    static_shape = tf.TensorShape([leading_dim]).concatenate(
        self.get_batch_shape())
    samples.set_shape(static_shape)
    return samples
评论列表
文章目录