test_utils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_log_combination(self):
        with self.test_session(use_gpu=True):
            def _test_func(n, ks):
                tf_n = tf.convert_to_tensor(n, tf.float32)
                tf_ks = tf.convert_to_tensor(ks, tf.float32)
                true_value = np.log(misc.factorial(n)) - \
                    np.sum(np.log(misc.factorial(ks)), axis=-1)
                test_value = log_combination(tf_n, tf_ks).eval()
                self.assertAllClose(true_value, test_value)

            _test_func(10, [1, 2, 3, 4])
            _test_func([1, 2], [[1], [2]])
            _test_func([1, 4], [[1, 0], [2, 2]])
            _test_func([[2], [3]], [[[0, 2], [1, 2]]])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号