utils.py source code

python

Project: zhusuan, author: thu-ml

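import tensorflow as tf


def assert_positive_integer(value, dtype, name):
    """
    Check that `value` is a positive scalar (a Python number or a 0-D tensor).
    If `value` is a built-in `int` or `float`, it is checked immediately.
    Otherwise, it is converted to a `dtype` tensor, and the checks are
    attached to the returned tensor so that they run when it is evaluated.

    :param value: The value to be checked.
    :param dtype: The tensor dtype.
    :param name: The name of `value` used in error messages.
    :return: The checked value.
    """
    sign_err_msg = name + " must be positive"
    # Python scalars are known now, so they can be validated eagerly.
    if isinstance(value, (int, float)):
        if value <= 0:
            raise ValueError(sign_err_msg)
        return value
    else:
        try:
            tensor = tf.convert_to_tensor(value, dtype)
        except (TypeError, ValueError):
            # convert_to_tensor raises TypeError for some bad inputs (e.g. a
            # float value with an integer dtype) and ValueError for others.
            raise TypeError(name + ' must be ' + str(dtype))
        # For tensors the value may only be known at run time, so build
        # assertion ops instead of checking in Python.
        _assert_rank_op = tf.assert_rank(
            tensor, 0,
            message=name + " should be a scalar (0-D Tensor).")
        _assert_positive_op = tf.assert_greater(
            tensor, tf.constant(0, dtype), message=sign_err_msg)
        # Route the tensor through tf.identity under control dependencies so
        # that evaluating the returned tensor forces both assertions to run.
        with tf.control_dependencies([_assert_rank_op,
                                      _assert_positive_op]):
            tensor = tf.identity(tensor)
        return tensor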
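The function splits its work in two: plain Python numbers are validated on the spot, while everything else becomes a tensor whose checks only fire when the graph evaluates it. The sketch below is a rough, TensorFlow-free analogy of that split, assuming a zero-argument callable can stand in for a graph tensor. `DeferredScalar` and `check_positive` are hypothetical names introduced for illustration; they are not part of zhusuan or TensorFlow.

```python
from typing import Callable, List


class DeferredScalar:
    """Hypothetical stand-in for a graph tensor: wraps a zero-argument
    producer plus checks that run only when the value is evaluated."""

    def __init__(self, producer: Callable[[], object],
                 checks: List[Callable[[object], None]]) -> None:
        self._producer = producer
        self._checks = checks

    def evaluate(self) -> object:
        value = self._producer()
        # Analogous to tf.control_dependencies + tf.identity: every check
        # runs before the value is handed back to the caller.
        for check in self._checks:
            check(value)
        return value


def check_positive(value: object, name: str) -> object:
    """Validate plain numbers eagerly; defer checks for lazy values."""
    sign_err_msg = name + " must be positive"
    # bool is a subclass of int, so exclude it explicitly.
    is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
    if is_number:
        if value <= 0:
            raise ValueError(sign_err_msg)
        return value
    if not callable(value):
        # Plays the role of a failed tf.convert_to_tensor call.
        raise TypeError(name + " must be a number or a zero-argument callable")

    def rank_check(v: object) -> None:
        # "Scalar" here means a plain number, like a 0-D tensor.
        if not isinstance(v, (int, float)) or isinstance(v, bool):
            raise ValueError(name + " should be a scalar")

    def sign_check(v: object) -> None:
        # Runs after rank_check, so v is known to be a number here.
        if v <= 0:
            raise ValueError(sign_err_msg)

    return DeferredScalar(value, [rank_check, sign_check])
```

For example, `check_positive(3, "n_samples")` returns `3` immediately, while `check_positive(lambda: -1, "n_samples")` returns a `DeferredScalar` without complaint and raises `ValueError` only when `.evaluate()` is called, which mirrors when a graph-mode user of the TensorFlow version would see the assertion fail. Unlike the original, the sketch rejects `bool`, since `True` would otherwise pass as the integer 1.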