def test_kl_gaussian_normal(random):
    """Test Gaussian/Normal KL.

    Builds a batch of ``MultivariateNormalTriL`` distributions ``q`` and a
    ``tf.distributions.Normal`` prior ``p`` with standard deviation 1.0.
    It then checks that ``kl_sum(q, p)`` evaluates to a scalar matching the
    reference ``KLdiv``. ``KLdiv`` receives the prior as an explicit list
    of scaled identity matrices, one per batch entry.
    """
    n_batch, n_dim = 5, 10
    chol_shape = (n_batch, n_dim, n_dim)

    # Posterior: random means plus factors from ``random_chol``
    # (presumably Cholesky / lower-triangular, given the TriL name).
    q_mean = random.randn(n_batch, n_dim).astype(np.float32)
    q_chol = random_chol(chol_shape)
    q = MultivariateNormalTriL(q_mean, q_chol)

    # Prior: random means with an isotropic scale of ``p_std``. The
    # reference implementation takes the scale as one identity-based matrix
    # per batch entry, so build an independent copy for each.
    p_mean = random.randn(n_batch, n_dim).astype(np.float32)
    p_std = 1.0
    p_scale = (p_std * np.eye(n_dim)).astype(np.float32)
    p_chol = [p_scale.copy() for _ in range(n_batch)]
    p = tf.distributions.Normal(p_mean, p_std)

    kl_tensor = kl_sum(q, p)
    kl_expected = KLdiv(q_mean, q_chol, p_mean, p_chol)

    case = tf.test.TestCase()
    with case.test_session():
        kl_value = kl_tensor.eval()

    assert np.isscalar(kl_value)
    assert np.allclose(kl_value, kl_expected)
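# NOTE(review): stray web-page navigation text was here ("评论列表" =
# "comment list", "文章目录" = "table of contents"); it is not part of the test.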