utils.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def assert_scalar(tensor, name):
    """
    Whether the `tensor` is a scalar (0-D tensor).

    :param tensor: A tensor to be checked.
    :param name: The name of `tensor` for error message.
    :return: The checked tensor.
    """
    static_shape = tensor.get_shape()
    shape_err_msg = name + " should be a scalar (0-D tensor)."
    if static_shape and (static_shape.ndims >= 1):
        raise ValueError(shape_err_msg)
    else:
        _assert_shape_op = tf.assert_rank(tensor, 0, message=shape_err_msg)
        with tf.control_dependencies([_assert_shape_op]):
            tensor = tf.identity(tensor)
        return tensor
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号