test_utils.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_log_mean_exp(self):
        with self.test_session(use_gpu=True) as sess:
            a = np.array([[[1., 3., 0.2], [0.7, 2., 1e-6]],
                          [[0., 1e6, 1.], [1., 1., 1.]]])
            for keepdims in [True, False]:
                true_values = misc.logsumexp(a, (0, 2), keepdims=keepdims) - \
                              np.log(a.shape[0] * a.shape[2])
                test_values = sess.run(log_mean_exp(
                    tf.constant(a), (0, 2), keepdims))
                self.assertAllClose(test_values, true_values)

            b = np.array([[0., 1e-6, 10.1]])
            test_values = sess.run(log_mean_exp(b, 0, keep_dims=False))
            self.assertTrue(np.abs(test_values - b).max() < 1e-6)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号