def assert_rank_at_least(tensor, k, name):
    """
    Ensure that the rank of `tensor` is at least `k`.

    If the static rank of `tensor` is known, the check is performed right
    away. Otherwise a runtime rank assertion is attached to the returned
    tensor through a control dependency.

    :param tensor: A tensor to be checked.
    :param k: The least rank allowed.
    :param name: The name of `tensor` for error message.

    :return: The checked tensor. When the static rank is known this is
        `tensor` itself; otherwise it is a `tf.identity` of `tensor` that
        depends on the runtime rank assertion.
    :raises ValueError: If the static rank of `tensor` is known and is
        smaller than `k`.
    """
    known_shape = tensor.get_shape()
    error_message = '{} should have rank >= {}.'.format(name, k)

    if known_shape:
        # Static rank available: fail fast instead of deferring to runtime.
        if known_shape.ndims < k:
            raise ValueError(error_message)
        return tensor

    # Static rank unknown: gate the tensor on an assertion evaluated at
    # runtime, so any consumer of the result triggers the check.
    rank_check = tf.assert_rank_at_least(tensor, k, message=error_message)
    with tf.control_dependencies([rank_check]):
        return tf.identity(tensor)
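# 评论列表 ("Comment list") — residue from the web page this snippet was copied from; not code.
# 文章目录 ("Table of contents") — likewise web-page residue; not code.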