def _sample(self, n_samples):
    """Draw ``n_samples`` samples as ``cov_tril @ eps + mean``, ``eps ~ N(0, I)``.

    Each draw multiplies standard-normal noise by ``self.cov_tril`` and
    shifts it by ``self.mean``, so the samples have covariance
    ``cov_tril @ cov_tril^T``. Presumably ``cov_tril`` is the
    lower-triangular Cholesky factor of the covariance (inferred from the
    name — confirm against the class definition).

    Args:
        n_samples: Number of samples to draw. A Python ``int`` lets the
            leading dimension be fixed in the static shape; any other value
            (e.g. a scalar tensor) leaves that dimension unknown.

    Returns:
        A tensor of samples whose static shape is set to
        ``[n_samples] + batch_shape + value_shape``. The leading dimension is
        ``None`` when ``n_samples`` is not a Python ``int``.
    """
    loc = self.mean
    scale_tril = self.cov_tril
    if not self.is_reparameterized:
        # Without reparameterization, no gradient may flow back into the
        # distribution parameters through the drawn samples.
        loc = tf.stop_gradient(loc)
        scale_tril = tf.stop_gradient(scale_tril)

    def _repeat_leading(x):
        # Prepend a new axis of length n_samples, replicating x along it.
        multiples = tf.concat([[n_samples], tf.ones_like(tf.shape(x))], 0)
        return tf.tile(tf.expand_dims(x, 0), multiples)

    tiled_loc = _repeat_leading(loc)
    tiled_scale = _repeat_leading(scale_tril)
    # Add a trailing axis so each mean vector becomes an (n_dim, 1) column
    # that can be the right operand of the batched matmul below.
    tiled_loc = tf.expand_dims(tiled_loc, -1)

    eps = tf.random_normal(tf.shape(tiled_loc), dtype=self.dtype)
    draws = tf.squeeze(tf.matmul(tiled_scale, eps) + tiled_loc, -1)

    # tf.tile/tf.squeeze lose static shape information; restore what is
    # known: the sample count (if a Python int), then batch and value shapes.
    if isinstance(n_samples, int):
        n_static = n_samples
    else:
        n_static = None
    known_shape = tf.TensorShape([n_static])
    known_shape = known_shape.concatenate(self.get_batch_shape())
    known_shape = known_shape.concatenate(self.get_value_shape())
    draws.set_shape(known_shape)
    return draws
评论列表
文章目录