def testNormalCDF(self):
    """Compare Normal.cdf with scipy's reference CDF for a batch of normals.

    A batch of 50 normal distributions is built with random locations and
    scales; each is evaluated at one of 50 evenly spaced points in
    [-8, 8]. Both the CDF values and the shape of the returned tensor are
    checked.
    """
    with self.test_session():
        n = 50
        # NOTE(review): assumes self._rng is a numpy RandomState-like object
        # (randn -> standard normal draws, rand -> uniform [0, 1)).
        loc = self._rng.randn(n)
        # Adding 1.0 keeps every scale away from zero.
        scale = self._rng.rand(n) + 1.0
        points = np.linspace(-8.0, 8.0, n).astype(np.float64)

        dist = normal_lib.Normal(loc=loc, scale=scale)
        reference = stats.norm(loc, scale).cdf(points)
        cdf_tensor = dist.cdf(points)

        # atol=0 disables absolute tolerance, so even very small tail
        # probabilities must agree to within the relative tolerance.
        self.assertAllClose(reference, cdf_tensor.eval(), atol=0)

        # The output shape must match the distribution's batch shape. It is
        # checked via the evaluated batch_shape() tensor and via
        # get_batch_shape(), each against the tensor's static shape and its
        # evaluated shape.
        self.assertAllEqual(dist.batch_shape().eval(), cdf_tensor.get_shape())
        self.assertAllEqual(dist.batch_shape().eval(), cdf_tensor.eval().shape)
        self.assertAllEqual(dist.get_batch_shape(), cdf_tensor.get_shape())
        self.assertAllEqual(dist.get_batch_shape(), cdf_tensor.eval().shape)
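# --- Residue from the web page this snippet was copied from (not code) ---
# normal_test.py — file source code
# language: python
# views: 26
# favorites: 0
# likes: 0
# comments: 0
# comment list
# article table of contents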