def upcast_float16_ufunc(fn):
    """Decorator that enforces computation is not done in float16 by NumPy.

    Some NumPy ufuncs compute floating-point results for int8, uint8 (and
    bool) inputs in half precision (float16). That is not precise enough
    and does not match what the C code computes.

    The wrapper promotes the dtypes of all positional arguments together
    with float16. When the result is float16 -- i.e. every input is float16
    or a type that fits in float16 -- it forces the ufunc onto its float32
    loop by passing a signature such as ``'ff->f'``. Otherwise ``fn`` is
    called unchanged.

    :param fn: numpy ufunc (must expose ``nin``/``nout`` and accept the
        ufunc ``signature`` keyword). It is assumed to produce
        floating-point outputs: when the upcast applies, *every* input and
        output is forced to float32, so e.g. ``numpy.add`` on two int8
        arrays would return float32 rather than int8.
    :returns: function similar to ``fn.__call__``, computing the same
        value with a minimum floating-point precision of float32. Every
        positional argument passed to it must have a ``dtype`` attribute.
    """
    import functools

    @functools.wraps(fn)
    def ret(*args, **kwargs):
        try:
            # numpy.find_common_type is deprecated since NumPy 1.25 and
            # removed in 2.0; result_type yields the same common dtype for
            # the numeric/bool inputs this decorator targets.
            out_dtype = numpy.result_type(numpy.float16,
                                          *[a.dtype for a in args])
        except TypeError:
            # The dtypes cannot be promoted together with float16 (e.g.
            # datetimes), so the computation cannot be float16 either;
            # let fn handle these inputs unchanged.
            out_dtype = None
        if out_dtype == 'float16':
            # Force everything to float32 ('f' for every input and output).
            # `signature` is the documented name of the legacy `sig`
            # keyword; drop a caller-supplied `sig` so the forced value
            # takes precedence, as it did before.
            sig = 'f' * fn.nin + '->' + 'f' * fn.nout
            kwargs.pop('sig', None)
            kwargs['signature'] = sig
        return fn(*args, **kwargs)

    return ret
评论列表
文章目录