def testGammaPdfOfSampleMultiDims(self):
    """Sample a broadcast (2, 2) Gamma batch and check its statistics.

    `alpha` is given with shape (2,) and `beta` with shape (2, 1), so the
    distribution broadcasts to a (2, 2) batch. The test checks three things:
    the static shapes of the sample and density tensors; the empirical
    mean/variance against scipy's reference values; and, for every batch
    member, the sample/pdf consistency check implemented by
    `self._assertIntegral`.
    """
    with session.Session() as sess:
        num = 50000
        gamma = gamma_lib.Gamma(alpha=[7., 11.], beta=[[5.], [6.]])
        samples_tensor = gamma.sample(num, seed=137)
        pdf_tensor = gamma.prob(samples_tensor)
        drawn, densities = sess.run([samples_tensor, pdf_tensor])

        # The leading sample dimension is followed by the (2, 2) batch shape.
        self.assertEqual(samples_tensor.get_shape(), (num, 2, 2))
        self.assertEqual(pdf_tensor.get_shape(), (num, 2, 2))

        # Reference parameters written out in the broadcast (2, 2) layout:
        # alpha varies along columns, beta along rows.
        expected_alpha = np.array([[7., 11.], [7., 11.]])
        # scipy is parameterized by `scale`, passed here as 1 / beta.
        expected_scale = 1 / np.array([[5., 5.], [6., 6.]])

        self.assertAllClose(
            stats.gamma.mean(expected_alpha, scale=expected_scale),
            drawn.mean(axis=0),
            atol=.1)
        self.assertAllClose(
            stats.gamma.var(expected_alpha, scale=expected_scale),
            drawn.var(axis=0),
            atol=.1)

        # Check each batch member in row-major order:
        # (0, 0), (0, 1), (1, 0), (1, 1).
        for row in range(2):
            for col in range(2):
                self._assertIntegral(
                    drawn[:, row, col], densities[:, row, col], err=0.02)
gamma_test.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录