def _gen_loopblocking_perprocess(
        nested_loop_desc, resource, cost, part_occ, options,
        gen_tifm, gen_tofm, gen_tbat, gen_ords):
    '''
    Sweep loop blocking candidates and return the lowest-cost schemes.

    Every combination of blocking factors (drawn from `gen_tifm`, `gen_tofm`
    and `gen_tbat`) and loop orders (drawn from `gen_ords`) is turned into a
    `LoopBlockingScheme`, built with `nested_loop_desc`, `resource`,
    `part_occ` and `options`.

    Returns a list of at most `options.ntops` schemes, ordered by ascending
    `get_cost(cost)`.
    '''

    def _blocking_factors():
        '''
        Yield the blocking factors of each candidate.

        The inputs are organized per loop (indexed by LoopEnum); each yielded
        tuple is the transposed, per-blocking-level (BL-major) view.
        '''
        per_loop_gens = [None] * le.NUM
        per_loop_gens[le.IFM] = gen_tifm
        per_loop_gens[le.OFM] = gen_tofm
        per_loop_gens[le.BAT] = gen_tbat

        for per_loop in itertools.product(*per_loop_gens):
            yield tuple(zip(*per_loop))

    def _candidates():
        ''' Yield a scheme for every combination that is not skipped. '''
        conv = _is_conv_loops(nested_loop_desc)
        # itertools.product drains both inputs before the first iteration, so
        # one-shot generators are safe here; do not unroll into nested loops.
        combos = itertools.product(_blocking_factors(), gen_ords)
        for factors, orders in combos:
            # skip_conv is only consulted for conv loops.
            if not conv or not skip_conv(factors, orders):
                yield LoopBlockingScheme(
                    nested_loop_desc, factors, orders,
                    resource, part_occ, options)

    def _scheme_cost(scheme):
        ''' Ranking key: cost of a scheme under the given cost model. '''
        return scheme.get_cost(cost)

    return heapq.nsmallest(options.ntops, _candidates(), key=_scheme_cost)
评论列表
文章目录