def test_init_check_shape(self):
    """Check that MultivariateNormalCholesky rejects mismatched shapes.

    Covers three situations:

    * arguments built from ``tf.zeros([])`` and ``tf.zeros([1])`` must
      raise ``ValueError`` matching "should have rank";
    * a ``[1, 2]`` first argument paired with a ``[1, 2, 3]`` second
      argument must raise ``ValueError`` matching "compatible";
    * when the second argument's shape depends on a fed placeholder, the
      mismatch must surface as ``tf.errors.InvalidArgumentError`` once
      the sample op is evaluated.

    NOTE(review): the two positional arguments are presumably the mean and
    the Cholesky factor of the covariance -- confirm against the
    MultivariateNormalCholesky signature.
    """
    # Everything runs inside the session context: the final `.eval()` call
    # relies on a default session being active.
    with self.test_session(use_gpu=True):
        # Scalar (rank-0) inputs.
        with self.assertRaisesRegexp(ValueError, "should have rank"):
            MultivariateNormalCholesky(tf.zeros([]), tf.zeros([]))

        # Rank-1 inputs on both arguments.
        with self.assertRaisesRegexp(ValueError, "should have rank"):
            MultivariateNormalCholesky(tf.zeros([1]), tf.zeros([1]))

        # Statically known shapes that do not agree with each other.
        with self.assertRaisesRegexp(ValueError, 'compatible'):
            MultivariateNormalCholesky(
                tf.zeros([1, 2]), tf.placeholder(tf.float32, [1, 2, 3]))

        # `len_u` is only known when `u` is fed, so the second argument's
        # shape cannot be validated while the graph is being built.
        u = tf.placeholder(tf.float32, [None])
        len_u = tf.shape(u)[0]
        dst = MultivariateNormalCholesky(
            tf.zeros([2]), tf.zeros([len_u, len_u]))

        # Feeding a length-3 `u` yields a [3, 3] second argument against a
        # length-2 first argument; the error is expected at run time.
        with self.assertRaisesRegexp(
                tf.errors.InvalidArgumentError, 'compatible'):
            dst.sample().eval(feed_dict={u: np.ones((3,))})
# (web-page residue, not code) Comment list
# (web-page residue, not code) Article table of contents