def cast_dtype(dtype, target):
    """Map floating-point dtypes onto ``target``; pass every other dtype through.

    This lets all float values share one working precision. As a side
    effect, non-TensorFlow dtypes (e.g. numpy dtypes) are normalized to
    TensorFlow dtypes via ``tf.as_dtype`` before the check.

    Parameters
    ----------
    dtype : ``tf.DType`` or :class:`~numpy:numpy.dtype`
        The dtype to normalize.
    target : ``tf.DType``
        Floating point dtype that replaces any floating input dtype.

    Returns
    -------
    ``tf.DType``
        ``target`` when the (normalized) input is a floating type,
        otherwise the normalized input dtype itself.
    """
    # Only convert when needed; tf.DType inputs are used as-is.
    tf_dtype = dtype if isinstance(dtype, tf.DType) else tf.as_dtype(dtype)

    # Floating types (of any width) collapse to the requested precision.
    return target if tf_dtype.is_floating else tf_dtype
# Comment list (stray web-page text)
# Table of contents (stray web-page text)