distributed.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:hpat 作者: IntelLabs 项目源码 文件源码
def _gen_init_reduce(self, reduce_var, reduce_op):
        """Generate IR that initializes a reduction variable on non-root
        processors.

        A tiny helper ``f(s)`` is assembled as source text, executed to obtain
        the function object, compiled to Numba IR, and its statements are
        returned with the final assignment retargeted to ``reduce_var``.

        The initial value handed to
        ``hpat.distributed_lower._root_rank_select`` depends on ``reduce_op``:

        * ``Sum``  -> ``0`` of the element type
        * ``Prod`` -> ``1`` of the element type
        * ``Min``  -> maximum value of the element type
        * ``Max``  -> minimum value of the element type

        For array reduction variables the value is first broadcast with
        ``np.full_like`` so that it matches the array argument.

        :param reduce_var: IR variable holding the reduction result; its
            ``name`` is used to look up the type in ``self.typemap``.
        :param reduce_op: a ``Reduce_Type`` member selecting the reduction.
        :return: list of IR statements, or an empty list for
            ``Argmin``/``Argmax``.
        """
        var_name = reduce_var.name
        red_var_typ = self.typemap[var_name]
        is_array = self._isarray(var_name)
        # Arrays are initialized element-wise, so use the dtype for them.
        el_typ = red_var_typ.dtype if is_array else red_var_typ

        # argmin/argmax are not initialized by the user and the correct
        # initialization is already in place, so nothing is generated.
        if reduce_op in [Reduce_Type.Argmin, Reduce_Type.Argmax]:
            return []

        # Pick the source-text expression for the reduction's identity value.
        init_val = None
        if reduce_op == Reduce_Type.Sum:
            init_val = str(el_typ(0))
        elif reduce_op == Reduce_Type.Prod:
            init_val = str(el_typ(1))
        elif reduce_op == Reduce_Type.Min:
            init_val = "numba.targets.builtins.get_type_max_value(np.ones(1,dtype=np.{}).dtype)".format(el_typ)
        elif reduce_op == Reduce_Type.Max:
            init_val = "numba.targets.builtins.get_type_min_value(np.ones(1,dtype=np.{}).dtype)".format(el_typ)

        assert init_val is not None

        if is_array:
            # Materialize a full array of the identity value and pass that.
            pre_init_val = "v = np.full_like(s, {}, s.dtype)".format(init_val)
            init_val = "v"
        else:
            pre_init_val = ""

        f_text = (
            "def f(s):\n"
            "  {}\n"
            "  s = hpat.distributed_lower._root_rank_select(s, {})"
        ).format(pre_init_val, init_val)

        # Run the generated source in a scratch namespace to get `f`.
        loc_vars = {}
        exec(f_text, {}, loc_vars)
        init_func = loc_vars['f']

        glbls = {'hpat': hpat, 'numba': numba, 'np': np}
        f_ir = compile_to_numba_ir(
            init_func, glbls, self.typingctx, (red_var_typ,),
            self.typemap, self.calltypes)
        f_block = f_ir.blocks.popitem()[1]
        replace_arg_nodes(f_block, [reduce_var])

        # Drop the last three statements of the compiled block (presumably
        # the epilogue emitted for `f` -- TODO confirm) and make the final
        # remaining assignment write into the reduction variable itself.
        init_nodes = f_block.body[:-3]
        init_nodes[-1].target = reduce_var
        return init_nodes
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号