def _sample(self, n_samples):
    """Draw ``n_samples`` samples as ``noise * std + mean``.

    ``noise`` is drawn from ``tf.random_normal`` with shape
    ``concat([[n_samples], self.batch_shape])`` and dtype ``self.dtype``.
    When ``self.is_reparameterized`` is false, ``mean`` and ``std`` are
    wrapped in ``tf.stop_gradient`` so no gradient flows back into them
    through the samples.

    Args:
        n_samples: Number of samples to draw. If it is a Python ``int``,
            the leading dimension of the result's static shape is set to
            it; otherwise (presumably a scalar int tensor — confirm with
            callers) that dimension is left unknown (``None``).

    Returns:
        A tensor whose static shape is ``[n_samples or None]`` followed by
        ``self.get_batch_shape()``.
    """
    loc = self.mean
    scale = self.std
    if not self.is_reparameterized:
        # Treat the parameters as constants for the sampling path.
        loc, scale = tf.stop_gradient(loc), tf.stop_gradient(scale)

    # Dynamic output shape: the sample axis prepended to the batch shape
    # (assumes self.batch_shape is a 1-D int tensor/list — TODO confirm).
    full_shape = tf.concat([[n_samples], self.batch_shape], 0)
    noise = tf.random_normal(full_shape, dtype=self.dtype)
    draws = noise * scale + loc

    # tf.concat loses static shape info, so restore what is knowable.
    if isinstance(n_samples, int):
        leading_dim = n_samples
    else:
        leading_dim = None
    static_shape = tf.TensorShape([leading_dim]).concatenate(
        self.get_batch_shape())
    draws.set_shape(static_shape)
    return draws
# 评论列表 (web-page scraping residue: "comment list")
# 文章目录 (web-page scraping residue: "table of contents")