base.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def __init__(self,
             dtype,
             param_dtype,
             is_continuous,
             is_reparameterized,
             use_path_derivative=False,
             group_ndims=0,
             **kwargs):
    """Store the configuration common to every distribution instance.

    The first five arguments are saved verbatim on private attributes
    (``_dtype``, ``_param_dtype``, ``_is_continuous``,
    ``_is_reparameterized``, ``_use_path_derivative``); this method does not
    inspect or validate them. ``group_ndims`` is validated and then stored as
    ``_group_ndims``.

    :param dtype: Stored as ``self._dtype`` (presumably the sample dtype --
        confirm against subclasses).
    :param param_dtype: Stored as ``self._param_dtype``.
    :param is_continuous: Stored as ``self._is_continuous``.
    :param is_reparameterized: Stored as ``self._is_reparameterized``.
    :param use_path_derivative: Stored as ``self._use_path_derivative``;
        defaults to ``False``.
    :param group_ndims: Either a Python ``int`` (checked eagerly) or anything
        ``tf.convert_to_tensor`` accepts as ``tf.int32`` (checked at graph run
        time). Must be a non-negative scalar. Defaults to ``0``.
    :param kwargs: Only ``group_event_ndims`` is consulted. It is the
        deprecated spelling of ``group_ndims``. When it is present it replaces
        ``group_ndims`` and a ``FutureWarning`` is emitted. Any other keys are
        ignored.
    :raises ValueError: If ``group_ndims`` is a Python ``int`` below zero.
        The other five attributes have already been assigned at that point.
    """
    # Work on a local copy so the deprecated alias can override it below.
    ndims = group_ndims
    if 'group_event_ndims' in kwargs:
        warnings.warn(
            "The argument `group_event_ndims` has been deprecated and "
            "will be removed in the coming version (0.3.1). Please use "
            "`group_ndims` instead.", FutureWarning)
        ndims = kwargs['group_event_ndims']

    # Plain attributes are recorded before group_ndims is validated, so they
    # remain set even if the check below raises.
    self._dtype = dtype
    self._param_dtype = param_dtype
    self._is_continuous = is_continuous
    self._is_reparameterized = is_reparameterized
    self._use_path_derivative = use_path_derivative

    if not isinstance(ndims, int):
        # Non-int input: convert to an int32 tensor. Its rank and sign cannot
        # be checked in Python here, so build TF assertion ops and make them
        # control dependencies of an identity op. The checks then run
        # whenever the stored tensor is evaluated.
        ndims = tf.convert_to_tensor(ndims, tf.int32)
        checks = [
            tf.assert_rank(
                ndims, 0,
                message="group_ndims should be a scalar (0-D Tensor)."),
            tf.assert_greater_equal(
                ndims, 0,
                message="group_ndims must be non-negative."),
        ]
        with tf.control_dependencies(checks):
            self._group_ndims = tf.identity(ndims)
    elif ndims < 0:
        raise ValueError("group_ndims must be non-negative.")
    else:
        self._group_ndims = ndims
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号