test_distributions.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_beta_log_prob(self):
        for _ in range(100):
            alpha = np.exp(np.random.normal())
            beta = np.exp(np.random.normal())
            dist = Beta(alpha, beta)
            x = dist.sample()
            actual_log_prob = dist.log_prob(x).sum()
            expected_log_prob = scipy.stats.beta.logpdf(x, alpha, beta)
            self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3)

    # This is a randomized test.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号