# Constructor of the FoldNormal distribution (ZhuSuan-style, TF 1.x graph mode).
# Relies on `tensorflow as tf`, `warnings`, and ZhuSuan's
# `assert_same_float_dtype` helper being in scope.
def __init__(self,
             mean=0.,
             logstd=None,
             std=None,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    self._mean = tf.convert_to_tensor(mean)
    warnings.warn("FoldNormal: The order of arguments logstd/std will change "
                  "to std/logstd in the coming version.", FutureWarning)
    # Exactly one of std / logstd must be given.
    if (logstd is None) == (std is None):
        raise ValueError("Either std or logstd should be passed but not "
                         "both of them.")
    elif logstd is None:
        # std was given: derive logstd from it.
        self._std = tf.convert_to_tensor(std)
        dtype = assert_same_float_dtype([(self._mean, 'FoldNormal.mean'),
                                         (self._std, 'FoldNormal.std')])
        logstd = tf.log(self._std)
        if check_numerics:
            logstd = tf.check_numerics(logstd, "log(std)")
        self._logstd = logstd
    else:
        # logstd was given: derive std from it.
        self._logstd = tf.convert_to_tensor(logstd)
        dtype = assert_same_float_dtype([(self._mean, 'FoldNormal.mean'),
                                         (self._logstd, 'FoldNormal.logstd')])
        std = tf.exp(self._logstd)
        if check_numerics:
            std = tf.check_numerics(std, "exp(logstd)")
        self._std = std
    # mean and std/logstd must have broadcast-compatible static shapes.
    try:
        tf.broadcast_static_shape(self._mean.get_shape(),
                                  self._std.get_shape())
    except ValueError:
        raise ValueError(
            "mean and std/logstd should be broadcastable to match each "
            "other. ({} vs. {})".format(
                self._mean.get_shape(), self._std.get_shape()))
    self._check_numerics = check_numerics
    super(FoldNormal, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
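
For reference, a minimal usage sketch. It assumes FoldNormal exposes the sample/log_prob interface of ZhuSuan's other distributions and runs under a TF 1.x session; those method names are assumptions, not shown in the excerpt above.

# Minimal usage sketch (assumptions: TF 1.x graph mode; FoldNormal follows
# ZhuSuan's Distribution interface with sample() and log_prob()).
import tensorflow as tf

dist = FoldNormal(mean=0., std=1.)       # pass exactly one of std / logstd
# dist = FoldNormal(mean=0., logstd=0.)  # equivalent parameterization

x = dist.sample(n_samples=5)             # draw 5 samples
log_p = dist.log_prob(x)                 # evaluate log-density at the samples

with tf.Session() as sess:
    print(sess.run([x, log_p]))

Note that passing both std and logstd (or neither) raises the ValueError from the constructor above, since the two parameterizations are redundant.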