def __init__(self,
             mean,
             cov_tril,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    """Validate the parameters and initialize the base distribution.

    Both parameters are converted to tensors, rank-checked, and checked
    for shape agreement twice: once against the static shapes at
    construction time, and once through an assertion op attached to
    ``self._cov_tril`` via a control dependency. The shared dtype
    returned by ``assert_same_float_dtype`` is passed to the base class
    as both ``dtype`` and ``param_dtype``.

    :param mean: Value convertible to a tensor. It is passed through
        ``assert_rank_at_least_one``; the size of its last axis is kept
        as ``self._n_dim``.
    :param cov_tril: Value convertible to a tensor. It is passed through
        ``assert_rank_at_least(..., 2, ...)`` and its shape must equal
        the shape of ``mean`` with one extra trailing axis of size
        ``self._n_dim``. (Presumably the lower-triangular Cholesky
        factor of the covariance, judging by the name -- confirm against
        the class documentation.)
    :param group_ndims: Forwarded unchanged to the base initializer.
    :param is_reparameterized: Forwarded unchanged to the base
        initializer.
    :param use_path_derivative: Forwarded unchanged to the base
        initializer.
    :param check_numerics: Stored as ``self._check_numerics``; it is not
        otherwise used in this method.
    :param kwargs: Extra keyword arguments forwarded to the base
        initializer.
    """
    self._check_numerics = check_numerics

    # Mean: convert, rank-check, and remember the size of the last axis.
    mean_tensor = assert_rank_at_least_one(
        tf.convert_to_tensor(mean), 'MultivariateNormalCholesky.mean')
    self._mean = mean_tensor
    n_dim = get_shape_at(mean_tensor, -1)
    self._n_dim = n_dim

    # Covariance factor: convert and rank-check.
    tril_tensor = assert_rank_at_least(
        tf.convert_to_tensor(cov_tril), 2,
        'MultivariateNormalCholesky.cov_tril')

    # Static shape check. When ``n_dim`` is not a Python int the
    # trailing dimension is left unknown (None), which is compatible
    # with any size.
    if isinstance(n_dim, int):
        static_last_dim = n_dim
    else:
        static_last_dim = None
    static_expected = mean_tensor.get_shape().concatenate(
        [static_last_dim])
    tril_tensor.get_shape().assert_is_compatible_with(static_expected)

    # Dynamic shape check: compare the run-time shapes and make the
    # stored covariance tensor depend on the assertion op.
    mean_shape = tf.shape(mean_tensor)
    dynamic_expected = tf.concat([mean_shape, [n_dim]], axis=0)
    dynamic_actual = tf.shape(tril_tensor)
    shape_check = tf.assert_equal(
        dynamic_expected,
        dynamic_actual,
        ['MultivariateNormalCholesky.cov_tril should have compatible '
         'shape with mean. Expected', dynamic_expected, ' got ',
         dynamic_actual])
    with tf.control_dependencies([shape_check]):
        self._cov_tril = tf.identity(tril_tensor)

    # Both parameters must share one floating-point dtype.
    named_params = [
        (self._mean, 'MultivariateNormalCholesky.mean'),
        (self._cov_tril, 'MultivariateNormalCholesky.cov_tril'),
    ]
    param_dtype = assert_same_float_dtype(named_params)

    super(MultivariateNormalCholesky, self).__init__(
        dtype=param_dtype,
        param_dtype=param_dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
评论列表
文章目录