def check_dtype(dtype, func_str, a, n):
    """Validate a user-supplied dtype or choose a result dtype for `func_str`.

    Parameters
    ----------
    dtype : dtype-like or None
        Output dtype requested by the caller. When not None it is validated
        against `func_str` and returned as an ``np.dtype``.
    func_str : str
        Name of the aggregation function (e.g. ``'sum'``, ``'prod'``,
        ``'len'``).
    a : scalar or array-like with ``shape`` and ``dtype``
        Input values. Scalars and 0-d arrays are accepted only for
        ``'sum'``, ``'prod'`` and ``'len'``.
    n : int
        Scale factor used when sizing the integer result of ``'sum'``
        (presumably the number of aggregated elements -- confirm with the
        callers).

    Returns
    -------
    numpy.dtype
        The dtype to use for the result.

    Raises
    ------
    ValueError
        If `a` is a scalar / 0-d array and `func_str` is not one of
        ``'sum'``, ``'prod'``, ``'len'``.
    TypeError
        If `dtype` is a bool type for a function whose name contains neither
        ``'all'`` nor ``'any'``, or a non-integer type for ``'len'`` /
        ``'nanlen'``.
    """
    if np.isscalar(a) or not a.shape:
        if func_str not in ("sum", "prod", "len"):
            raise ValueError("scalar inputs are supported only for 'sum', "
                             "'prod' and 'len'")
        # 0-d arrays and numpy scalars carry their own dtype. Going through
        # type(a) would turn a 0-d ndarray into dtype('O'), so only plain
        # Python scalars fall back to np.dtype(type(a)).
        if hasattr(a, "dtype"):
            a_dtype = a.dtype
        else:
            a_dtype = np.dtype(type(a))
    else:
        a_dtype = a.dtype

    if dtype is not None:
        # dtype set by the user
        # Careful here: np.bool != np.bool_ !
        if np.issubdtype(dtype, np.bool_) and \
                not ('all' in func_str or 'any' in func_str):
            raise TypeError("function %s requires a more complex datatype "
                            "than bool" % func_str)
        if not np.issubdtype(dtype, np.integer) \
                and func_str in ('len', 'nanlen'):
            raise TypeError("function %s requires an integer datatype"
                            % func_str)
        # TODO: Maybe have some more checks here
        return np.dtype(dtype)

    # No dtype requested: functions listed in _forced_types always get the
    # dtype registered for them there.
    try:
        return np.dtype(_forced_types[func_str])
    except KeyError:
        pass

    if func_str in _forced_float_types:
        # Keep the input's float precision, otherwise default to float64.
        if np.issubdtype(a_dtype, np.floating):
            return a_dtype
        return np.dtype(np.float64)

    if func_str == 'sum':
        # Try to guess the minimally required int size
        if np.issubdtype(a_dtype, np.int64):
            # It's not getting bigger anymore
            # TODO: strictly speaking it might need float
            return np.dtype(np.int64)
        if np.issubdtype(a_dtype, np.integer):
            # Worst case: every one of the n summands is the type's maximum.
            maxval = np.iinfo(a_dtype).max * n
            return minimum_dtype(maxval, a_dtype)
        if np.issubdtype(a_dtype, np.bool_):
            # A sum of n booleans is at most n.
            return minimum_dtype(n, a_dtype)
        # floating, inexact, whatever
        return a_dtype

    if func_str in _forced_same_type:
        return a_dtype

    # a_dtype is an np.dtype instance, so it has to be tested with
    # np.issubdtype; isinstance(a_dtype, np.integer) is never true.
    if np.issubdtype(a_dtype, np.integer):
        return np.dtype(np.int64)
    return a_dtype
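# Scraped page residue, not part of the code: "Comment list"
# Scraped page residue, not part of the code: "Table of contents"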