test_basic.py source file

python
Project: Theano-Deep-learning  Author: GeekLiB
import numpy


def upcast_float16_ufunc(fn):
    """Decorator that enforces computation is not done in float16 by NumPy.

    Some ufuncs in NumPy will compute float values on int8 and uint8
    in half-precision (float16), which is not enough, and not compatible
    with the C code.

    :param fn: numpy ufunc
    :returns: function similar to fn.__call__, computing the same
        value with a minimum floating-point precision of float32
    """
    def ret(*args, **kwargs):
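        # Combine the argument dtypes with float16 to find the dtype NumPy
        # would compute in. Note: numpy.find_common_type was deprecated in
        # NumPy 1.25 and removed in NumPy 2.0, so this needs an older NumPy.
        out_dtype = numpy.find_common_type(
            [a.dtype for a in args], [numpy.float16])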
        if out_dtype == 'float16':
            # Force everything to float32
            sig = 'f' * fn.nin + '->' + 'f' * fn.nout
            kwargs.update(sig=sig)
        return fn(*args, **kwargs)

    return ret
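
The effect the docstring describes can be seen with numpy.sqrt on an int8 array. The sketch below is not the original Theano code: the name upcast_float16_ufunc_sketch is hypothetical, and numpy.result_type is swapped in for numpy.find_common_type as an assumption so that the example also runs on recent NumPy releases. The two functions do not follow identical promotion rules, so treat this as an illustration of the idea only.

```python
import numpy


def upcast_float16_ufunc_sketch(fn):
    """Hypothetical variant of the decorator, for illustration only."""
    def ret(*args, **kwargs):
        # Assumption: numpy.result_type stands in for find_common_type.
        out_dtype = numpy.result_type(numpy.float16, *args)
        if out_dtype == numpy.float16:
            # Select the float32 loop for every input and output.
            kwargs['signature'] = 'f' * fn.nin + '->' + 'f' * fn.nout
        return fn(*args, **kwargs)

    return ret


x = numpy.array([1, 4, 9], dtype=numpy.int8)

plain = numpy.sqrt(x)                                  # unwrapped ufunc
wrapped = upcast_float16_ufunc_sketch(numpy.sqrt)(x)   # wrapped ufunc

print(plain.dtype)    # float16
print(wrapped.dtype)  # float32
```

Inputs that already have float32 or higher precision are passed through unchanged, because the signature is only forced when the combined dtype is float16.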