def testNormalLogCDF(self):
    """Compare Normal.log_cdf with scipy.stats.norm.logcdf over a batch.

    Builds a batch of 50 Normal distributions whose locations come from
    ``self._rng.randn`` and whose scales are ``self._rng.rand`` plus 1.0.
    It evaluates ``log_cdf`` at 50 evenly spaced points from -100.0 to 10.0
    and checks the values against SciPy (relative tolerance only).
    It then checks that both the static and the evaluated shapes of the
    result agree with the distribution's dynamic and static batch shapes.
    """
    with self.test_session():
        n = 50
        # Draw locations before scales to keep the RNG stream order fixed.
        loc = self._rng.randn(n)
        scale = 1.0 + self._rng.rand(n)
        points = np.linspace(-100.0, 10.0, n).astype(np.float64)

        dist = normal_lib.Normal(loc=loc, scale=scale)
        reference = stats.norm(loc, scale).logcdf(points)
        log_cdf = dist.log_cdf(points)

        # Run the graph once and reuse the result for every check below.
        log_cdf_val = log_cdf.eval()
        # atol=0: with strongly negative inputs the log-CDF values vary widely
        # in magnitude, so only a relative tolerance is used.
        self.assertAllClose(reference, log_cdf_val, atol=0, rtol=1e-5)

        # Dynamic (runtime) batch shape vs. the result's static/evaluated shape.
        dynamic_batch_shape = dist.batch_shape().eval()
        self.assertAllEqual(dynamic_batch_shape, log_cdf.get_shape())
        self.assertAllEqual(dynamic_batch_shape, log_cdf_val.shape)

        # Static batch shape vs. the result's static/evaluated shape.
        static_batch_shape = dist.get_batch_shape()
        self.assertAllEqual(static_batch_shape, log_cdf.get_shape())
        self.assertAllEqual(static_batch_shape, log_cdf_val.shape)
# normal_test.py (file source code)
# Language: python
# Views: 36
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents