def testNormalLogCdfAndLogSurvivalFunction(self):
    # On integer-valued inputs, the quantized normal's log-CDF and
    # log-survival function are expected to agree with the un-quantized
    # scipy normal built from the same loc and scale.
    shape = (3, 3)
    loc = rng.randn(*shape)
    # Offset by 1.0 to keep the scale away from zero (assuming rng.rand
    # yields non-negative values -- confirm against the rng definition).
    scale = rng.rand(*shape) + 1.0
    # Evaluation points drawn via randint, then cast to float64.
    points = rng.randint(-10, 10, size=shape).astype(np.float64)
    reference = stats.norm(loc, scale)

    with self.test_session():
        base_normal = distributions.Normal(loc=loc, scale=scale)
        quantized = distributions.QuantizedDistribution(
            distribution=base_normal)

        expected_log_cdf = reference.logcdf(points)
        actual_log_cdf = quantized.log_cdf(points).eval()
        self.assertAllClose(expected_log_cdf, actual_log_cdf)

        expected_log_sf = reference.logsf(points)
        actual_log_sf = quantized.log_survival_function(points).eval()
        self.assertAllClose(expected_log_sf, actual_log_sf)
quantized_distribution_test.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录