def test_prob(self):
    """Check ``log_prob`` and ``prob`` of ``MultivariateNormalCholesky``.

    For each parameter seed, samples are drawn from the distribution and
    the log-density it reports at those samples is compared, one batch
    entry at a time, against ``stats.multivariate_normal.logpdf``
    evaluated with the full covariance matrix. Finally ``prob`` is
    checked to equal ``exp(log_prob)`` on the same samples.
    """
    # NOTE(review): the (10, 11) batch shape is assumed to match what
    # self._gen_test_params produces -- confirm against that helper.
    batch_dims = (10, 11)
    num_samples = 200

    def check_seed(seed):
        mean, cov, cov_chol = self._gen_test_params(seed)
        dist = MultivariateNormalCholesky(
            tf.constant(mean),
            tf.constant(cov_chol),
            check_numerics=True,
        )
        drawn = dist.sample(num_samples).eval()

        # The static shape of the symbolic result is checked first, then
        # the shape of the evaluated array.
        log_prob_tensor = dist.log_prob(tf.constant(drawn))
        expected_shape = (num_samples,) + batch_dims
        self.assertEqual(
            log_prob_tensor.shape.as_list(), list(expected_shape))
        log_prob_value = log_prob_tensor.eval()
        self.assertEqual(log_prob_value.shape, expected_shape)

        # Compare every batch entry with the reference log-density.
        for i, j in np.ndindex(*batch_dims):
            reference = stats.multivariate_normal.logpdf(
                drawn[:, i, j, :], mean[i, j], cov[i, j])
            self.assertAllClose(reference, log_prob_value[:, i, j])

        # prob() must agree with exponentiating log_prob().
        prob_value = dist.prob(tf.constant(drawn)).eval()
        self.assertAllClose(np.exp(log_prob_value), prob_value)

    # All tensor evaluation happens inside the session context.
    with self.fixed_randomness_session(233):
        for seed in [23, 233, 2333]:
            check_seed(seed)
评论列表
文章目录