test_multivariate.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_value(self):
        with self.test_session(use_gpu=True):
            def _test_value(logits, n_experiments, given):
                logits = np.array(logits, np.float32)
                normalized_logits = logits - misc.logsumexp(
                    logits, axis=-1, keepdims=True)
                given = np.array(given)
                dist = Multinomial(logits, n_experiments)
                log_p = dist.log_prob(given)
                target_log_p = np.log(misc.factorial(n_experiments)) - \
                    np.sum(np.log(misc.factorial(given)), -1) + \
                    np.sum(given * normalized_logits, -1)
                self.assertAllClose(log_p.eval(), target_log_p)
                p = dist.prob(given)
                target_p = np.exp(target_log_p)
                self.assertAllClose(p.eval(), target_p)

            _test_value([-50., -20., 0.], 4, [1, 0, 3])
            _test_value([1., 10., 1000.], 1, [1, 0, 0])
            _test_value([[2., 3., 1.], [5., 7., 4.]], 3,
                        np.ones([3, 1, 3], dtype=np.int32))
            _test_value([-10., 10., 20., 50.], 100, [[0, 1, 99, 100],
                                                     [100, 99, 1, 0]])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号