parallel.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:untwist 作者: IoSR-Surrey 项目源码 文件源码
def __call__(self, process_func):

        def wrapper(*args):
            data_obj = args[1]
            if (len(data_obj.shape) <= self.input_dim
                or data_obj.shape[-1] == 1):
                return process_func(*args)
            else:
                pool = mp.Pool(mp.cpu_count())# TODO: make configurable
                arglist = [
                    (args[0],) +
                    (data_obj[...,i],) +
                    args[2:]
                    for i in range(data_obj.shape[-1])
                ]
                result = pool.map(self.worker, arglist)
                if self.output_dim > self.input_dim: # expanding
                    return np.stack(result, -1)
                else: # contracting
                    return np.concatenate(result, -1)
        return wrapper
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号