test_multivariate.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_shape_inference(self):
        with self.test_session(use_gpu=True):
            # Static
            mean = 10 * np.random.normal(size=(10, 11, 2)).astype('d')
            cov = np.zeros((10, 11, 2, 2))
            dst = MultivariateNormalCholesky(
                tf.constant(mean), tf.constant(cov))
            self.assertEqual(dst.get_batch_shape().as_list(), [10, 11])
            self.assertEqual(dst.get_value_shape().as_list(), [2])
            # Dynamic
            unk_mean = tf.placeholder(tf.float32, None)
            unk_cov = tf.placeholder(tf.float32, None)
            dst = MultivariateNormalCholesky(unk_mean, unk_cov)
            self.assertEqual(dst.get_value_shape().as_list(), [None])
            feed_dict = {unk_mean: np.ones(2), unk_cov: np.eye(2)}
            self.assertEqual(list(dst.batch_shape.eval(feed_dict)), [])
            self.assertEqual(list(dst.value_shape.eval(feed_dict)), [2])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号