def assert_scalar(tensor, name):
    """
    Check that `tensor` is a scalar (0-D tensor).

    If the static shape of `tensor` already reports a rank of 1 or more, the
    check fails immediately. Otherwise a rank-0 assertion op is created and
    the returned tensor is given a control dependency on it.

    :param tensor: A tensor to be checked.
    :param name: The name of `tensor` for error message.
    :return: The checked tensor: an identity of `tensor` that carries a
        control dependency on the rank assertion.
    :raises ValueError: If the static rank of `tensor` is known and >= 1.
    """
    known_shape = tensor.get_shape()
    err_msg = name + " should be a scalar (0-D tensor)."

    # `ndims` is None when the rank is unknown at graph-construction time;
    # only a known, non-zero rank can be rejected up front.
    known_rank = known_shape.ndims
    if known_rank is not None and known_rank >= 1:
        raise ValueError(err_msg)

    rank_check = tf.assert_rank(tensor, 0, message=err_msg)
    # Route the result through an identity op so that it depends on the
    # rank assertion.
    with tf.control_dependencies([rank_check]):
        checked = tf.identity(tensor)
    return checked
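# [scraped web-page residue, not code] 评论列表 ("Comment list")
# [scraped web-page residue, not code] 文章目录 ("Table of contents")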