convolution.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def __init__(self, lib, dtype,
                 N, C, K,
                 D, H, W,
                 T, R, S,
                 M, P, Q,
                 pad_d, pad_h, pad_w,
                 str_d, str_h, str_w,
                 dil_d, dil_h, dil_w):
        """Configure autotuning, scratch space and warmup for this kernel.

        Every argument is forwarded unchanged to the parent constructor.
        Afterwards this method:

        * builds ``self.autotune_key`` from the SM count, ``dtype.itemsize``,
          whether ``lib.deterministic`` is positive, and the twelve size
          arguments ``N, C, K, D, H, W, T, R, S, M, P, Q``;
        * points ``self.autotune_db_file`` at a per-Python-major-version
          file inside ``lib.cache_dir``;
        * calls ``self.init()``;
        * asks ``lib`` for scratch space of ``self.output_trans.size``;
        * sets ``self.warmup`` to an iteration count clamped to [1, 5000].

        Raises:
            AssertionError: if ``N`` is not a multiple of 4.
        """
        assert N % 4 == 0, "N dim must be multiple of 4"

        super(UpdateDirect, self).__init__(
            lib, dtype,
            N, C, K,
            D, H, W,
            T, R, S,
            M, P, Q,
            pad_d, pad_h, pad_w,
            str_d, str_h, str_w,
            dil_d, dil_h, dil_w)

        SMs = _get_sm_count()

        # The key string identifies this configuration in the autotune db;
        # it must stay byte-stable or previously stored entries are missed.
        key_fields = (
            "direct_updat_64x32", SMs, dtype.itemsize, lib.deterministic > 0,
            N, C, K, D, H, W, T, R, S, M, P, Q)
        self.autotune_key = " ".join(map(native_str, key_fields))

        # The Python major version goes into the filename to avoid Py2/Py3
        # incompatibilities in shelve.
        cache_dir = lib.cache_dir
        db_name = "autotune%d.db" % sys.version_info[0]
        self.autotune_db_file = os.path.join(cache_dir, db_name)
        self.init()

        lib.set_scratch_size(self.output_trans.size)

        # Allow for .5 seconds worth of warmup when autotuning, assuming
        # 5 Tflops on 24 SMs and scaling linearly with the actual SM count.
        # NOTE(review): 0.5 s * 5 Tflops would be 2.5e12, but the constant
        # used here is 2e12 -- confirm which one is intended.
        flops_per_call = M * P * Q * K * N * C * T * R * S * 2.0
        sm_scale = SMs / 24.0
        warmup_iters = int(2e12 / flops_per_call * sm_scale)

        # Never fewer than one warmup iteration, never more than 5000.
        at_least_one = max(warmup_iters, 1)
        self.warmup = min(at_least_one, 5000)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号