def upcast_int8_nfunc(fn):
    """Decorator that upcasts int8 and uint8 positional inputs to float32.

    This is so that floating-point computation is not carried out using
    half-precision (float16), as some NumPy functions do when they are
    given 8-bit integer inputs.

    Only positional arguments are inspected; keyword arguments are passed
    to ``fn`` untouched.  Positional arguments that have no ``dtype``
    attribute, or whose dtype is neither int8 nor uint8, are forwarded
    unchanged.

    :param fn: function computing a floating-point value from inputs
    :returns: function similar to fn (same name and docstring), but
        upcasting its uint8 and int8 positional inputs before carrying
        out the computation.
    """
    # Imported locally so the block does not depend on the file's import
    # header, which is not visible here.
    import functools

    @functools.wraps(fn)  # keep fn's __name__/__doc__ on the wrapper
    def ret(*args, **kwargs):
        # The dtype is compared against strings so that both NumPy dtype
        # objects and plain string dtypes are recognised.
        args = [
            a.astype('float32')
            if getattr(a, 'dtype', None) in ('int8', 'uint8') else a
            for a in args
        ]
        return fn(*args, **kwargs)
    return ret
评论列表
文章目录