def assert_same_float_and_int_dtype(tensors_with_name, dtype=None):
    """
    Check that all given tensors share one floating-point or integer dtype.

    The permitted dtypes are `tf.float16`, `tf.float32`, `tf.float64`,
    `tf.int16`, `tf.int32` and `tf.int64`.

    :param tensors_with_name: A list of (tensor, tensor_name).
    :param dtype: Expected dtype. If `None`, no particular dtype is imposed
        and the check is delegated to `assert_same_specific_dtype` together
        with the list of permitted dtypes; otherwise it must be one of the
        permitted dtypes and the check is delegated to `assert_same_dtype`.
    :return: The result of the delegated check, i.e. the type of the
        tensors.
    :raises TypeError: If `dtype` is given but is not a permitted dtype.
    """
    # Kept as a list on purpose: it is forwarded to the helper as-is and is
    # interpolated into the error message with `%`, where a tuple would be
    # unpacked as multiple format arguments.
    permitted_dtypes = [
        tf.float16, tf.float32, tf.float64,
        tf.int16, tf.int32, tf.int64,
    ]

    # No explicit expectation: the tensors only have to agree with each
    # other on one of the permitted dtypes.
    if dtype is None:
        return assert_same_specific_dtype(tensors_with_name, permitted_dtypes)

    # Reject an unsupported expectation before looking at any tensor.
    if dtype not in permitted_dtypes:
        raise TypeError(
            "The argument 'dtype' must be in %s" % permitted_dtypes)

    return assert_same_dtype(tensors_with_name, dtype)
评论列表
文章目录