python numba.njit() 函数的实例源码

aggregate_numba.py 文件源码 项目:mobula 作者: wkcn 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _2pass(cls):
        """Build the second-pass loop for this aggregation (not jitted yet).

        ``cls._2pass_inner`` is compiled with ``nb.njit`` and evaluated for
        every output slot whose ``counter`` entry is truthy; every other slot
        is overwritten with ``fill_value``.

        Returns:
            A plain Python function ``(ret, counter, mean, fill_value, ddof)``
            that fills ``ret`` in place and returns ``None``. It is returned
            un-jitted; ``callable`` wraps it with ``nb.njit`` itself.
        """
        finish_slot = nb.njit(cls._2pass_inner)

        def _2pass_loop(ret, counter, mean, fill_value, ddof):
            n_slots = len(ret)
            for slot in range(n_slots):
                if not counter[slot]:
                    # Nothing was counted for this slot: use the fill value.
                    ret[slot] = fill_value
                    continue
                ret[slot] = finish_slot(slot, ret, counter, mean, ddof)

        return _2pass_loop
aggregate_numba.py 文件源码 项目:mobula 作者: wkcn 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def callable(cls, nans=False, reverse=False, scalar=False):
        """Compile a jitted function doing the hard part of the job.

        Builds an accumulation loop around the class's ``_inner`` kernel and,
        if ``cls._2pass()`` returns a function, chains that second pass after
        the loop.

        Args:
            nans: If true, any value that compares unequal to itself (NaN) is
                redirected to output slot 0, and every other group index is
                shifted up by one. The output array must therefore reserve
                that extra leading slot.
            reverse: If true, visit ``group_idx`` from the last element to the
                first instead of first to last.
            scalar: If true, ``a`` is a single value used for every element
                instead of an array indexed in step with ``group_idx``.

        Returns:
            A function ``(group_idx, a, ret, counter, mean, fill_value, ddof)``
            that returns ``None``. Results are presumably written into ``ret``
            (and possibly ``counter``/``mean``) by ``_inner`` and
            ``_2pass_inner`` (confirm against those kernels). Without a second
            pass this is the jitted loop itself. Otherwise it is a plain
            Python wrapper that calls the jitted loop and then the jitted
            second pass.
        """
        if scalar:
            def _valgetter(a, i):
                return a
        else:
            def _valgetter(a, i):
                return a[i]
        valgetter = nb.njit(_valgetter)

        if nans:
            def _ri_redir(i, val):
                """ Redirect any write access to the output array to its
                    first field, if we encounter a nan value. This first field
                    was reserved in advance for dummy access. Shift the index
                    by 1, if we don't have a nan value.
                """
                # (val == val) is False only for NaN, so NaNs land on slot 0.
                return (i + 1) * (val == val)
        else:
            def _ri_redir(i, val):
                return i
        ri_redir = nb.njit(_ri_redir)

        inner = nb.njit(cls._inner)

        def _loop(group_idx, a, ret, counter, mean, fill_value, ddof):
            # fill_value and ddof are unused here. They keep this signature
            # interchangeable with loop_2pass below, so callers can use either.
            # NOTE(review): numba does not bounds-check array indexing by
            # default, so group_idx is presumably validated by the caller --
            # confirm.
            n = len(group_idx)
            rng = range(n - 1, -1, -1) if reverse else range(n)
            for i in rng:
                val = valgetter(a, i)
                ri = ri_redir(group_idx[i], val)
                inner(ri, val, ret, counter, mean)
        loop = nb.njit(_loop, nogil=True, cache=True)

        _2pass = cls._2pass()
        if _2pass is None:
            return loop

        _2pass = nb.njit(_2pass, nogil=True, cache=True)

        def loop_2pass(group_idx, a, ret, counter, mean, fill_value, ddof):
            # Plain Python glue: run the accumulation loop, then the second
            # pass over the accumulated state.
            loop(group_idx, a, ret, counter, mean, fill_value, ddof)
            _2pass(ret, counter, mean, fill_value, ddof)
        return loop_2pass


问题


面经


文章

微信
公众号

扫码关注公众号