def auto_dtype(A, B):
    """
    Get promoted datatype for A and B combined.

    Parameters
    ----------
    A : ndarray
        First input; only its ``dtype`` attribute is read.
    B : ndarray
        Second input; only its ``dtype`` attribute is read.

    Returns
    -------
    precision : numpy.dtype
        Datatype that would be used after applying NumPy type promotion rules
        to ``A.dtype`` and ``B.dtype``. If it is not a real floating-point
        dtype (e.g. an integer, boolean or complex dtype), ``float32`` is
        returned instead.
    """
    # Datatype that would be used after applying NumPy type promotion rules.
    precision = np.result_type(A.dtype, B.dtype)
    # Test against the abstract ``np.floating`` rather than the builtin
    # ``float``: newer NumPy treats ``float`` as ``float64`` here, which made
    # float16 (and extended-precision) results get rewritten to float32.
    if not np.issubdtype(precision, np.floating):
        # Return a dtype instance (not the scalar type) so both branches
        # yield the same kind of object; it still compares equal to np.float32.
        precision = np.dtype(np.float32)
    return precision
评论列表
文章目录