test_basic.py source code

python

Project: Theano-Deep-learning Author: GeekLiB
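import functools


def upcast_int8_nfunc(fn):
    """Decorator that upcasts inputs of dtype int8 or uint8 to float32.

    This ensures the floating-point computation is not carried out in
    half precision (float16), which is what some NumPy functions (for
    example ufuncs such as numpy.sin) do for 8-bit integer inputs.

    :param fn: function computing a floating-point value from its inputs
    :returns: function similar to fn, but upcasting its int8 and uint8
        inputs to float32 before carrying out the computation.
    """
    @functools.wraps(fn)  # keep fn's name and docstring on the wrapper
    def ret(*args, **kwargs):
        args = list(args)
        for i, a in enumerate(args):
            # Only positional arguments with an 8-bit integer dtype are cast;
            # keyword arguments and all other values pass through unchanged.
            if getattr(a, 'dtype', None) in ('int8', 'uint8'):
                args[i] = a.astype('float32')

        return fn(*args, **kwargs)

    return ret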
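The following sketch shows why the decorator exists. It assumes NumPy is installed and repeats the decorator body from the listing so the block runs on its own; `safe_sin`, `raw`, `wrapped`, and `untouched` are hypothetical names used only for illustration. Calling `np.sin` directly on an int8 array yields a float16 result, while the wrapped version computes in float32 and leaves non-8-bit inputs alone.

```python
import numpy as np


def upcast_int8_nfunc(fn):
    """Copy of the decorator above, repeated so this block runs on its own."""
    def ret(*args, **kwargs):
        args = list(args)
        for i, a in enumerate(args):
            # Only objects exposing an int8/uint8 dtype are converted;
            # plain Python numbers and other dtypes pass through untouched.
            if getattr(a, 'dtype', None) in ('int8', 'uint8'):
                args[i] = a.astype('float32')
        return fn(*args, **kwargs)
    return ret


x = np.array([0, 1, 2], dtype=np.int8)

# NumPy picks the smallest float loop that int8 casts to safely: float16.
raw = np.sin(x)

# Hypothetical wrapped ufunc: the int8 input is upcast to float32 first.
safe_sin = upcast_int8_nfunc(np.sin)
wrapped = safe_sin(x)

# A float64 input is not an int8/uint8 array, so it stays float64.
untouched = safe_sin(np.array([0.0, 1.0]))

print(raw.dtype, wrapped.dtype, untouched.dtype)  # float16 float32 float64
```

Casting to float32 rather than float64 keeps the result at the precision NumPy already uses for 16-bit integer inputs, which is presumably enough for the test comparisons while avoiding float16 rounding.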