def assert_positive_integer(value, dtype, name):
    """
    Check that `value` is a positive scalar and return it.

    A built-in ``int`` or ``float`` is validated immediately in Python and
    returned unchanged. Any other input is converted to a tensor of type
    `dtype`; assertion ops for "rank 0" and "greater than 0" are created and
    attached to the returned tensor through control dependencies.

    Note that, despite the function name, built-in values are not checked
    for integrality: a positive ``float`` is accepted as-is.

    :param value: The value to be checked.
    :param dtype: The tensor dtype used when `value` is not a built-in number.
    :param name: The name of `value` used in error messages.
    :return: `value` itself for built-in numbers, otherwise the converted
        tensor guarded by the assertion ops.
    :raises ValueError: If a built-in `value` is not strictly greater than
        zero (this includes NaN).
    :raises TypeError: If `value` cannot be converted to a `dtype` tensor.
    """
    sign_err_msg = name + " must be positive"

    if isinstance(value, (int, float)):
        # `not value > 0` instead of `value <= 0`: every comparison with NaN
        # is False, so the latter would silently let NaN through, while the
        # tensor path below (assert_greater) rejects it.
        if not value > 0:
            raise ValueError(sign_err_msg)
        return value

    try:
        tensor = tf.convert_to_tensor(value, dtype)
    except (TypeError, ValueError):
        # convert_to_tensor may signal an incompatible value with either
        # exception type; report both uniformly.
        raise TypeError(name + ' must be ' + str(dtype))

    _assert_rank_op = tf.assert_rank(
        tensor, 0,
        message=name + " should be a scalar (0-D Tensor).")
    _assert_positive_op = tf.assert_greater(
        tensor, tf.constant(0, dtype), message=sign_err_msg)
    # Route the tensor through an identity op that depends on both
    # assertions, so consumers of the returned tensor depend on them too.
    with tf.control_dependencies([_assert_rank_op,
                                  _assert_positive_op]):
        tensor = tf.identity(tensor)
    return tensor
评论列表
文章目录