loop_blocking.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:nn_dataflow 作者: stanford-mast 项目源码 文件源码
def _gen_loopblocking_perprocess(
        nested_loop_desc, resource, cost, part_occ, options,
        gen_tifm, gen_tofm, gen_tbat, gen_ords):
    '''
    Build loop blocking schemes from the given generators and keep the best.

    Every combination of blocking factors (drawn from `gen_tifm`, `gen_tofm`
    and `gen_tbat`) and loop orders (drawn from `gen_ords`) is turned into a
    `LoopBlockingScheme`, except those rejected by `skip_conv` when
    `_is_conv_loops(nested_loop_desc)` is true.

    Returns a list with at most `options.ntops` schemes, namely those with
    the smallest `get_cost(cost)`, in ascending cost order.
    '''

    def _iter_blocking_factors():
        '''
        Yield blocking factor tuples, one per combination.

        Each combination from the per-loop generators is indexed by LoopEnum;
        it is transposed with `zip` so the outer index becomes the second
        axis of the generated items (the blocking level, per the original
        naming -- confirm against the generators).
        '''
        per_loop_gens = [None] * le.NUM
        per_loop_gens[le.IFM] = gen_tifm
        per_loop_gens[le.OFM] = gen_tofm
        per_loop_gens[le.BAT] = gen_tbat
        for combo in itertools.product(*per_loop_gens):
            yield tuple(zip(*combo))

    def _iter_schemes():
        ''' Lazily yield a scheme for every non-skipped combination. '''
        conv_loops = _is_conv_loops(nested_loop_desc)
        all_choices = itertools.product(_iter_blocking_factors(), gen_ords)
        for bl_ts, bl_ords in all_choices:
            # The skip check is only consulted for conv loops.
            if not (conv_loops and skip_conv(bl_ts, bl_ords)):
                yield LoopBlockingScheme(
                    nested_loop_desc, bl_ts, bl_ords, resource, part_occ,
                    options)

    def _scheme_cost(scheme):
        ''' Sorting key: cost of a scheme under the given cost model. '''
        return scheme.get_cost(cost)

    return heapq.nsmallest(options.ntops, _iter_schemes(), key=_scheme_cost)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号