def assert_same_float_dtype(tensors_with_name, dtype=None):
    """
    Check that all given tensors share one floating-point dtype.

    :param tensors_with_name: A list of (tensor, tensor_name) pairs.
    :param dtype: The expected dtype. Must be one of tf.float16, tf.float32
        or tf.float64. If `None`, the dtype depends on the tensors
        themselves; the check is delegated to `assert_same_specific_dtype`
        with those three types as the allowed set.
    :return: The type of the tensors, as returned by the delegated helper.
    :raises TypeError: If `dtype` is not `None` and is not one of the three
        allowed floating types. Errors raised by the delegated helpers
        propagate unchanged.
    """
    allowed = [tf.float16, tf.float32, tf.float64]

    # No expectation given: let the helper work out the common dtype,
    # restricted to the allowed floating types.
    if dtype is None:
        return assert_same_specific_dtype(tensors_with_name, allowed)

    # Reject a non-floating expectation before any tensor is inspected.
    # `allowed` stays a list: with `%`, a tuple would be unpacked as
    # multiple format arguments.
    if dtype not in allowed:
        raise TypeError("The argument 'dtype' must be in %s" % allowed)

    return assert_same_dtype(tensors_with_name, dtype)
评论列表
文章目录