def lower_dist_cumsum(context, builder, sig, args):
    """Lower a distributed cumulative sum over an input/output array pair.

    Compiles an internal implementation through ``context.compile_internal``
    that, given an input array and an output array:

    1. totals every element of the input array,
    2. hands that total to ``distributed_api.dist_exscan`` to obtain the
       starting value of the running sum (presumably the combined total of
       the preceding ranks -- confirm against ``dist_exscan``), and
    3. stores the running sum, seeded with that value, element by element
       into the output array.

    Args:
        context: Lowering context; only ``compile_internal`` is used here.
        builder: IR builder, forwarded unchanged to ``compile_internal``.
        sig: Call signature; ``sig.args[0].dtype`` determines the type of
            both accumulators.
        args: Lowered argument values, forwarded unchanged.

    Returns:
        The value produced by ``context.compile_internal`` for the call.
    """
    elem_type = sig.args[0].dtype
    # Zero built from the element dtype; it seeds the local total below.
    start = elem_type(0)

    def cumsum_impl(src, dst):
        # Pass 1: total of all elements in the input array.
        local_total = start
        for elem in np.nditer(src):
            local_total += elem.item()

        # Starting value of the running sum, derived from the local total.
        running = distributed_api.dist_exscan(local_total)

        # Pass 2: running sum written into the output array.
        # NOTE(review): single-integer indexing over ``src.size`` looks like
        # it assumes a 1-D array -- confirm against callers.
        for idx in range(src.size):
            running += src[idx]
            dst[idx] = running
        return 0

    # Pin both accumulators to the element type; the keys must match the
    # local variable names used inside ``cumsum_impl``.
    typed_locals = {"local_total": elem_type, "running": elem_type}
    return context.compile_internal(
        builder, cumsum_impl, sig, args, locals=typed_locals
    )
评论列表
文章目录