def __init__(self,
             dtype,
             param_dtype,
             is_continuous,
             is_reparameterized,
             use_path_derivative=False,
             group_ndims=0,
             **kwargs):
    """Store the constructor arguments and validate ``group_ndims``.

    Args:
        dtype: Stored unchanged as ``self._dtype``.
        param_dtype: Stored unchanged as ``self._param_dtype``.
        is_continuous: Stored unchanged as ``self._is_continuous``.
        is_reparameterized: Stored unchanged as
            ``self._is_reparameterized``.
        use_path_derivative: Stored unchanged as
            ``self._use_path_derivative``. Defaults to ``False``.
        group_ndims: Either a non-negative Python ``int`` (checked
            immediately and stored as-is), or any other value, which is
            converted to a ``tf.int32`` tensor and stored behind rank-0
            and non-negativity assertion ops. Defaults to ``0``.
        **kwargs: Only the deprecated key ``group_event_ndims`` is read.
            When present it overrides ``group_ndims`` and a
            ``FutureWarning`` is issued. Any other key is ignored.

    Raises:
        ValueError: If ``group_ndims`` is a negative Python ``int``.
    """
    # Backward compatibility: the old keyword wins over `group_ndims`.
    if 'group_event_ndims' in kwargs:
        warnings.warn(
            "The argument `group_event_ndims` has been deprecated and "
            "will be removed in the coming version (0.3.1). Please use "
            "`group_ndims` instead.", FutureWarning)
        group_ndims = kwargs['group_event_ndims']

    self._dtype = dtype
    self._param_dtype = param_dtype
    self._is_continuous = is_continuous
    self._is_reparameterized = is_reparameterized
    self._use_path_derivative = use_path_derivative

    if isinstance(group_ndims, int):
        # Static value: validate now and keep the plain Python int.
        if group_ndims < 0:
            raise ValueError("group_ndims must be non-negative.")
        self._group_ndims = group_ndims
    else:
        # Non-int value: convert to an int32 tensor and attach the checks
        # as control dependencies of the tensor that is actually stored.
        group_ndims = tf.convert_to_tensor(group_ndims, tf.int32)
        _assert_rank_op = tf.assert_rank(
            group_ndims, 0,
            message="group_ndims should be a scalar (0-D Tensor).")
        _assert_nonnegative_op = tf.assert_greater_equal(
            group_ndims, 0,
            message="group_ndims must be non-negative.")
        with tf.control_dependencies([_assert_rank_op,
                                      _assert_nonnegative_op]):
            self._group_ndims = tf.identity(group_ndims)
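# 评论列表 ("comment list") -- appears to be leftover web-page text, not code
# 文章目录 ("table of contents") -- appears to be leftover web-page text, not code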