def __init__(self,
             alpha,
             group_ndims=0,
             check_numerics=False,
             **kwargs):
    """Construct the distribution from its ``alpha`` parameter.

    :param alpha: A value convertible to a Tensor via
        ``tf.convert_to_tensor``. It must have rank >= 1, and its last axis
        (the categories axis) must have length >= 2. Its dtype is checked by
        ``assert_same_float_dtype`` (defined elsewhere; presumably it rejects
        non-float inputs -- confirm).
    :param group_ndims: Forwarded unchanged to the base-class constructor.
    :param check_numerics: Stored as ``self._check_numerics``.
    :param kwargs: Extra keyword arguments forwarded to the base class.

    :raises ValueError: When the *static* shape of ``alpha`` already shows
        rank 0, or shows a last-axis length below 2. If either fact is only
        known at graph run time, runtime assertion ops are attached to
        ``self._alpha`` instead of raising here.
    """
    self._alpha = tf.convert_to_tensor(alpha)
    dtype = assert_same_float_dtype(
        [(self._alpha, 'Dirichlet.alpha')])

    shape_err_msg = "alpha should have rank >= 1."
    cat_err_msg = ("n_categories (length of the last axis "
                   "of alpha) should be at least 2.")

    alpha_shape = self._alpha.get_shape()
    # ``None`` means the rank is not known statically.
    known_rank = alpha_shape.ndims

    if known_rank is not None and known_rank < 1:
        raise ValueError(shape_err_msg)

    # Rank is either unknown or >= 1 here, so indexing [-1] is safe when the
    # rank is known.
    if known_rank is None:
        last_dim = None
    else:
        last_dim = alpha_shape[-1].value

    if last_dim is not None:
        # Category count is known at graph-construction time: check it now.
        self._n_categories = last_dim
        if last_dim < 2:
            raise ValueError(cat_err_msg)
    else:
        # Fall back to runtime checks. Each assertion is threaded through
        # ``tf.identity`` so that later uses of ``self._alpha`` depend on it;
        # the rank check is wired in first so it guards the ``tf.shape`` read.
        rank_check = tf.assert_rank_at_least(
            self._alpha, 1, message=shape_err_msg)
        with tf.control_dependencies([rank_check]):
            self._alpha = tf.identity(self._alpha)

        self._n_categories = tf.shape(self._alpha)[-1]
        category_check = tf.assert_greater_equal(
            self._n_categories, 2, message=cat_err_msg)
        with tf.control_dependencies([category_check]):
            self._alpha = tf.identity(self._alpha)

    self._check_numerics = check_numerics
    super(Dirichlet, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
# Web-page residue (not code): "评论列表" ("comment list"), "文章目录" ("table of contents")