def testDirichletSample(self):
    """Checks Dirichlet sampling against the analytic Beta marginal.

    Draws 100000 samples from a Dirichlet with concentration [1., 2], then
    verifies that:
      * the evaluated samples have shape (100000, 2);
      * every sampled coordinate is strictly greater than 0.0;
      * the first coordinate passes a Kolmogorov-Smirnov comparison against
        Beta(a=1, b=2), i.e. the KS statistic is below 0.01.
    """
    num_draws = 100000
    with self.test_session():
        concentration = [1., 2]
        dist = dirichlet_lib.Dirichlet(concentration)
        draws = dist.sample(constant_op.constant(num_draws)).eval()

        self.assertEqual(draws.shape, (num_draws, 2))
        all_positive = np.all(draws > 0.0)
        self.assertTrue(all_positive)

        # A single coordinate of Dirichlet(alpha) is univariate
        # Beta(alpha_i, sum(alpha) - alpha_i); with alpha = [1, 2] the first
        # coordinate should therefore follow Beta(1, 2).
        first_coordinate = draws[:, 0]
        reference_cdf = stats.beta(a=1., b=2.).cdf
        ks_statistic = stats.kstest(first_coordinate, reference_cdf)[0]
        self.assertLess(ks_statistic, 0.01)
dirichlet_test.py 文件源码
python
阅读 59
收藏 0
点赞 0
评论 0
评论列表
文章目录