def _gen_init_reduce(self, reduce_var, reduce_op):
    """generate code to initialize reduction variables on non-root
    processors.
    """
    red_var_typ = self.typemap[reduce_var.name]
    el_typ = red_var_typ
    if self._isarray(reduce_var.name):
        el_typ = red_var_typ.dtype
    init_val = None
    pre_init_val = ""
    if reduce_op == Reduce_Type.Sum:
        init_val = str(el_typ(0))
    if reduce_op == Reduce_Type.Prod:
        init_val = str(el_typ(1))
    if reduce_op == Reduce_Type.Min:
        init_val = "numba.targets.builtins.get_type_max_value(np.ones(1,dtype=np.{}).dtype)".format(el_typ)
    if reduce_op == Reduce_Type.Max:
        init_val = "numba.targets.builtins.get_type_min_value(np.ones(1,dtype=np.{}).dtype)".format(el_typ)
    if reduce_op in [Reduce_Type.Argmin, Reduce_Type.Argmax]:
        # don't generate initialization for argmin/argmax since they are not
        # initialized by the user and the correct initialization is already in place
        return []
    assert init_val is not None
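    # for array reductions, broadcast the scalar identity into an array with
    # the same shape and dtype as the reduction variable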
    if self._isarray(reduce_var.name):
        pre_init_val = "v = np.full_like(s, {}, s.dtype)".format(init_val)
        init_val = "v"
f_text = "def f(s):\n {}\n s = hpat.distributed_lower._root_rank_select(s, {})".format(pre_init_val, init_val)
loc_vars = {}
exec(f_text, {}, loc_vars)
f = loc_vars['f']
f_block = compile_to_numba_ir(f, {'hpat': hpat, 'numba': numba, 'np': np},
self.typingctx, (red_var_typ,), self.typemap, self.calltypes).blocks.popitem()[1]
replace_arg_nodes(f_block, [reduce_var])
nodes = f_block.body[:-3]
nodes[-1].target = reduce_var
return nodes
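# For illustration, assuming a scalar float64 Sum reduction, the generated
# helper text looks roughly like
#     def f(s):
#         s = hpat.distributed_lower._root_rank_select(s, 0.0)
# while for an array Max reduction pre_init_val first builds
# "v = np.full_like(s, <type min value>, s.dtype)" and v is passed as the
# non-root value.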