test_multivariate.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_value(self):
        with self.test_session(use_gpu=True):
            def _test_value(given, temperature, logits):
                given = np.array(given, np.float32)
                logits = np.array(logits, np.float32)
                n = logits.shape[-1]

                t = temperature
                target_log_p = special.gammaln(n) + (n - 1) * np.log(t) + \
                    (logits - t * given).sum(axis=-1) - \
                    n * np.log(np.exp(logits - t * given).sum(axis=-1))

                con = ExpConcrete(temperature, logits=logits)
                log_p = con.log_prob(given)
                self.assertAllClose(log_p.eval(), target_log_p)
                p = con.prob(given)
                self.assertAllClose(p.eval(), np.exp(target_log_p))

            _test_value([np.log(0.25), np.log(0.25), np.log(0.5)],
                        0.1,
                        [1., 1., 1.2])
            _test_value([[np.log(0.25), np.log(0.25), np.log(0.5)],
                        [np.log(0.1), np.log(0.5), np.log(0.4)]],
                        0.5,
                        [[1., 1., 1.], [.5, .5, .4]])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号