def testNormalEntropy(self):
    """Check Normal.entropy() against its closed form on a broadcast batch.

    loc has shape (3,) and scale has shape (3, 1), so they broadcast to a
    (3, 3) batch. The expected entropy is the closed form
    0.5 * log(2 * pi * e * sigma**2), which depends only on the scale.
    Both the values and the shapes of the entropy tensor are checked.
    """
    with self.test_session():
        mu_v = np.array([1.0, 1.0, 1.0])
        sigma_v = np.array([[1.0, 2.0, 3.0]]).T
        normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

        # scipy.stats.norm cannot deal with these shapes, so evaluate the
        # closed form directly. Broadcast the scale against a ones array
        # shaped like loc, rather than multiplying by loc itself. This keeps
        # the expected values correct even if mu_v is changed away from all
        # ones (the entropy must not depend on the loc values).
        sigma_broadcast = sigma_v * np.ones_like(mu_v)
        expected_entropy = 0.5 * np.log(
            2 * np.pi * np.exp(1) * sigma_broadcast**2)

        entropy = normal.entropy()
        np.testing.assert_allclose(expected_entropy, entropy.eval())

        # The entropy's static shape and its evaluated shape must both equal
        # the batch shape. That batch shape is obtained two ways: as a tensor
        # that is evaluated (batch_shape()) and directly (get_batch_shape()).
        self.assertAllEqual(normal.batch_shape().eval(), entropy.get_shape())
        self.assertAllEqual(normal.batch_shape().eval(), entropy.eval().shape)
        self.assertAllEqual(normal.get_batch_shape(), entropy.get_shape())
        self.assertAllEqual(normal.get_batch_shape(), entropy.eval().shape)
normal_test.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录