utils.py source code

python
Project: zhusuan  Author: thu-ml
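import tensorflow as tf  # TF 1.x-style API; TF 2.x exposes the op as tf.debugging.assert_rank_at_least


def assert_rank_at_least(tensor, k, name):
    """
    Assert that the rank of `tensor` is at least `k`.

    The check is done statically when the rank is known at graph-construction
    time; otherwise a runtime assertion is attached to the returned tensor.

    :param tensor: A tensor to be checked.
    :param k: The minimum rank allowed.
    :param name: The name of `tensor`, used in the error message.
    :return: The checked tensor.
    :raises ValueError: If the static rank is known and is less than `k`.
    """
    static_shape = tensor.get_shape()
    shape_err_msg = '{} should have rank >= {}.'.format(name, k)
    # A TensorShape is truthy only when its rank is statically known.
    if static_shape and (static_shape.ndims < k):
        raise ValueError(shape_err_msg)
    if not static_shape:
        # Rank unknown until run time: defer the check to an assert op and
        # make the returned tensor depend on it.
        _assert_shape_op = tf.assert_rank_at_least(
            tensor, k, message=shape_err_msg)
        with tf.control_dependencies([_assert_shape_op]):
            tensor = tf.identity(tensor)
    return tensor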
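
A usage sketch may help show the two code paths. It is not part of zhusuan: the function is repeated here with `tf.compat.v1` names so the block can also run under TensorFlow 2.x (assuming TF 1.14 or later, where `tf.compat.v1` exists), and `run_demo`, the placeholder names, and the `results` dictionary are hypothetical names introduced only for illustration.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.14+ or 2.x


def assert_rank_at_least(tensor, k, name):
    # Same logic as the function above, spelled with tf.compat.v1 names.
    static_shape = tensor.get_shape()
    shape_err_msg = '{} should have rank >= {}.'.format(name, k)
    if static_shape and (static_shape.ndims < k):
        raise ValueError(shape_err_msg)
    if not static_shape:
        check = tf1.assert_rank_at_least(tensor, k, message=shape_err_msg)
        with tf.control_dependencies([check]):
            tensor = tf.identity(tensor)
    return tensor


def run_demo():
    """Exercise the static and the run-time check (hypothetical helper)."""
    results = {}
    with tf.Graph().as_default():
        # Static path: the rank (1) is known while building the graph,
        # so the ValueError is raised immediately.
        vec = tf1.placeholder(tf.float32, shape=[None])
        try:
            assert_rank_at_least(vec, 2, 'vec')
            results['static'] = 'passed'
        except ValueError as err:
            results['static'] = str(err)

        # Dynamic path: the rank is unknown, so the check runs only
        # when the returned tensor is evaluated.
        unknown = tf1.placeholder(tf.float32, shape=None)
        checked = assert_rank_at_least(unknown, 2, 'unknown')
        with tf1.Session() as sess:
            out = sess.run(
                checked, feed_dict={unknown: np.zeros((3, 4), np.float32)})
            results['dynamic_ok_shape'] = out.shape
            try:
                sess.run(
                    checked, feed_dict={unknown: np.zeros(3, np.float32)})
                results['dynamic_bad'] = 'passed'
            except tf.errors.InvalidArgumentError:
                results['dynamic_bad'] = 'InvalidArgumentError'
    return results


if __name__ == '__main__':
    print(run_demo())
```

Returning `tf.identity(tensor)` under the control dependency is what ties the assertion to the data flow: code that uses the returned tensor cannot be evaluated without the check running first, whereas code that keeps using the original tensor would bypass it.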