test_multivariate.py source file

python

Project: zhusuan  Author: thu-ml
def test_init_check_shape(self):
    with self.test_session(use_gpu=True):
        # Scalar (rank-0) arguments fail the static rank check.
        with self.assertRaisesRegexp(ValueError, "should have rank"):
            MultivariateNormalCholesky(tf.zeros([]), tf.zeros([]))
        # Rank-1 arguments also fail the static rank check.
        with self.assertRaisesRegexp(ValueError, "should have rank"):
            MultivariateNormalCholesky(tf.zeros([1]), tf.zeros([1]))
        # Statically known shapes [1, 2] and [1, 2, 3] are reported as
        # incompatible at construction time.
        with self.assertRaisesRegexp(ValueError, 'compatible'):
            MultivariateNormalCholesky(
                tf.zeros([1, 2]), tf.placeholder(tf.float32, [1, 2, 3]))
        # With a dynamic shape, the mismatch (a mean of size 2 against a
        # 3x3 matrix) can only be detected when the graph is run.
        u = tf.placeholder(tf.float32, [None])
        len_u = tf.shape(u)[0]
        dst = MultivariateNormalCholesky(
            tf.zeros([2]), tf.zeros([len_u, len_u]))
        with self.assertRaisesRegexp(
                tf.errors.InvalidArgumentError, 'compatible'):
            dst.sample().eval(feed_dict={u: np.ones((3,))})