def _testPDFShapes(self, mvn_dist, mu, sigma):
    """Check batch shape and values of `log_prob`/`prob` for fed parameters.

    Builds a distribution with ``mvn_dist(mu, sigma)`` and evaluates
    ``log_prob`` and ``prob`` at ``x = 2 * ones_like(mu)``. It then runs the
    graph, feeding a [3, 3] batch of 2-dimensional parameters: every mean is
    all ones and every matrix is the 2x2 identity. The results must have
    shape [3, 3] and match ``scipy.stats.multivariate_normal`` for every
    batch entry.

    Args:
      mvn_dist: Callable taking ``(mu, sigma)`` and returning a distribution
        object with ``log_prob`` and ``prob`` methods.
      mu: Tensor used as a ``feed_dict`` key, so it must be feedable. It is
        presumably a placeholder with unspecified shape, so the batch shape
        is only known at run time (confirm at call sites). It is fed a
        [3, 3, 2] array of ones.
      sigma: Tensor fed a [3, 3, 2, 2] batch of identity matrices. It has the
        same feedability requirement as ``mu``.
    """
    with self.test_session() as sess:
        mvn = mvn_dist(mu, sigma)
        # x is built in-graph from mu, so it takes mu's fed shape; x_value
        # below is its NumPy counterpart for the SciPy reference.
        x = 2 * array_ops.ones_like(mu)
        log_pdf = mvn.log_prob(x)
        pdf = mvn.prob(x)

        batch_shape = [3, 3]
        event_dim = 2
        mu_value = np.ones(batch_shape + [event_dim])
        # The identity is its own Cholesky factor (and its own inverse), so
        # the SciPy reference, which takes a covariance, is valid however
        # mvn_dist interprets sigma (e.g. covariance or scale factor).
        sigma_value = np.zeros(batch_shape + [event_dim, event_dim])
        sigma_value[:] = np.identity(event_dim)
        x_value = 2. * np.ones_like(mu_value)
        feed_dict = {mu: mu_value, sigma: sigma_value}

        # All batch members share the same parameters, so one SciPy
        # distribution is a valid reference for every entry. logpdf/pdf treat
        # the last axis of x_value as the event dimension and return an array
        # of shape batch_shape.
        scipy_mvn = stats.multivariate_normal(
            mean=mu_value[0, 0], cov=sigma_value[0, 0])
        expected_log_pdf = scipy_mvn.logpdf(x_value)
        expected_pdf = scipy_mvn.pdf(x_value)

        log_pdf_evaled, pdf_evaled = sess.run(
            [log_pdf, pdf], feed_dict=feed_dict)
        self.assertAllEqual(batch_shape, log_pdf_evaled.shape)
        self.assertAllEqual(batch_shape, pdf_evaled.shape)
        # Compare every batch entry, not only [0, 0], so a defect affecting
        # other batch positions is caught.
        self.assertAllClose(expected_log_pdf, log_pdf_evaled)
        self.assertAllClose(expected_pdf, pdf_evaled)
mvn_test.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录