def test_sample(self):
    """Check sample shapes and empirical moments of the distribution.

    For each parameter seed, a ``MultivariateNormalCholesky`` is built
    from the generated mean and Cholesky factor, and the test verifies:

    * the static shape of the sample tensor is ``(20000, 10, 11, 3)``;
    * the evaluated sample array has that same shape;
    * the sample mean over axis 0 is close to ``mean``
      (``rtol=atol=5e-2``);
    * for every ``(i, j)`` position, the covariance of the samples'
      last axis is close to ``cov[i, j]`` (``rtol=atol=1e-1``).
    """
    param_seeds = [23, 233, 2333]
    n_samples = 20000
    expected_shape = (n_samples, 10, 11, 3)

    def check_samples(seed):
        mean, cov, cov_chol = self._gen_test_params(seed)
        mean_t = tf.constant(mean)
        chol_t = tf.constant(cov_chol)
        dist = MultivariateNormalCholesky(mean_t, chol_t)

        # Static (graph-time) shape of the symbolic sample tensor.
        symbolic = dist.sample(n_samples)
        self.assertEqual(symbolic.shape.as_list(), list(expected_shape))

        # A second sample op is created and evaluated for the numeric
        # checks; `.eval()` presumably relies on the session opened by
        # `fixed_randomness_session` being the default -- TODO confirm.
        drawn = dist.sample(n_samples).eval()
        self.assertEqual(drawn.shape, expected_shape)

        empirical_mean = np.mean(drawn, axis=0)
        self.assertAllClose(empirical_mean, mean, rtol=5e-2, atol=5e-2)

        # Visit every (i, j) position in row-major order and compare
        # the empirical covariance of the last axis with cov[i, j].
        for i, j in np.ndindex(10, 11):
            empirical_cov = np.cov(drawn[:, i, j, :].T)
            self.assertAllClose(
                empirical_cov, cov[i, j], rtol=1e-1, atol=1e-1)

    with self.fixed_randomness_session(233):
        for seed in param_seeds:
            check_samples(seed)
评论列表
文章目录