def normalize_num_type(num_type):
    """
    Map a numeric type onto the project-wide default float or int type.

    Any 32- or 64-bit float type is replaced by ``settings.float_type`` and
    any 16-, 32- or 64-bit signed integer type is replaced by
    ``settings.int_type``, regardless of the precision of the input (so a
    float64 may be downcast and a float32 may be upcast, depending on the
    configured defaults).

    Args:
        num_type: a TensorFlow ``tf.DType`` or a numpy scalar type such as
            ``np.float64`` or ``np.int32``.

    Returns:
        ``settings.float_type`` for float inputs, ``settings.int_type`` for
        integer inputs.

    Raises:
        ValueError: if ``num_type`` is not one of the supported float or
            integer types (e.g. ``np.float16``, unsigned ints, bool).
    """
    if isinstance(num_type, tf.DType):
        # as_numpy_dtype may be a numpy scalar type or an np.dtype (assumed to
        # vary by dtype / TF version); np.dtype(...).type yields the scalar
        # type class in either case.
        num_type = np.dtype(num_type.as_numpy_dtype).type
    if num_type in (np.float32, np.float64):  # pylint: disable=E1101
        return settings.float_type
    if num_type in (np.int16, np.int32, np.int64):
        return settings.int_type
    raise ValueError('Unknown dtype "{0}" passed to normalizer.'.format(num_type))
# def types_array(tensor, shape=None):
# shape = shape if shape is not None else tensor.shape.as_list()
# return np.full(shape, tensor.dtype).tolist()
# Scraped web-page residue, not part of the module:
# "评论列表" (comment list), "文章目录" (table of contents)