def testExponentialSampleMultiDimensional(self):
  """KS-tests every (batch, rate) slice of a batched Exponential's samples.

  Builds a rate tensor by repeating ``lam_v`` ``batch_size`` times, draws ``n``
  samples, checks the sample shape and that no sample is negative, then runs
  a one-sample Kolmogorov-Smirnov test of each (batch row, rate) column
  against ``scipy.stats.expon`` with ``scale = 1 / rate``. The test passes
  when every KS statistic is below 0.01.
  """
  with self.test_session():
    batch_size = 2
    lam_v = [3.0, 22.0]
    num_rates = len(lam_v)
    lam = constant_op.constant([lam_v] * batch_size)
    exponential = exponential_lib.Exponential(lam=lam)
    n = 100000
    samples = exponential.sample(n, seed=138)
    # Samples are laid out as (n, batch_size, num_rates).
    self.assertEqual(samples.get_shape(), (n, batch_size, num_rates))
    sample_values = samples.eval()
    self.assertFalse(np.any(sample_values < 0.0))
    # Every batch row repeats the same rates, so each (b, i) column is compared
    # to the same reference distribution. Iterate over all batch rows rather
    # than hand-unrolling rows 0 and 1, so changing batch_size or lam_v keeps
    # full coverage.
    for i, rate in enumerate(lam_v):
      # scipy parameterizes the exponential by scale = 1 / rate.
      reference_cdf = stats.expon(scale=1.0 / rate).cdf
      for b in range(batch_size):
        # kstest returns (statistic, pvalue); only the statistic is checked.
        ks_statistic = stats.kstest(sample_values[:, b, i], reference_cdf)[0]
        self.assertLess(ks_statistic, 0.01)
exponential_test.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录