def test_shape_inference(self):
    # Verifies the shapes reported by MultivariateNormalCholesky in two
    # situations: when both inputs have fully known static shapes, and when
    # both inputs are placeholders whose shapes are only known at run time.
    with self.test_session(use_gpu=True):
        # --- Statically known shapes ---
        # The first input is (10, 11, 2) and the second is (10, 11, 2, 2);
        # the test expects a batch shape of [10, 11] and a value shape of [2].
        static_mean = 10 * np.random.normal(size=(10, 11, 2)).astype('d')
        static_cov = np.zeros((10, 11, 2, 2))
        mean_tensor = tf.constant(static_mean)
        cov_tensor = tf.constant(static_cov)
        static_dist = MultivariateNormalCholesky(mean_tensor, cov_tensor)

        static_batch = static_dist.get_batch_shape().as_list()
        self.assertEqual(static_batch, [10, 11])
        static_value = static_dist.get_value_shape().as_list()
        self.assertEqual(static_value, [2])

        # --- Shapes only available at run time ---
        # Placeholders are created with shape=None, so the static value
        # shape is expected to be a single unknown dimension.
        mean_ph = tf.placeholder(tf.float32, None)
        cov_ph = tf.placeholder(tf.float32, None)
        dynamic_dist = MultivariateNormalCholesky(mean_ph, cov_ph)
        self.assertEqual(dynamic_dist.get_value_shape().as_list(), [None])

        # Feeding a length-2 vector and a 2x2 identity matrix should give an
        # empty batch shape and a value shape of [2] when evaluated.
        feeds = {mean_ph: np.ones(2), cov_ph: np.eye(2)}
        runtime_batch = dynamic_dist.batch_shape.eval(feed_dict=feeds)
        self.assertEqual(list(runtime_batch), [])
        runtime_value = dynamic_dist.value_shape.eval(feed_dict=feeds)
        self.assertEqual(list(runtime_value), [2])
评论列表
文章目录