def __call__(self, process_func):
    """Wrap *process_func* so multi-channel data is processed slice-by-slice in parallel.

    The returned wrapper treats its second positional argument (``args[1]``)
    as the data object. If that object has at most ``self.input_dim``
    dimensions, or its last axis has length 1, ``process_func(*args)`` is
    called directly. Otherwise the data is split along its last axis; each
    slice takes the place of ``args[1]``, and the resulting argument tuples
    are mapped over ``self.worker`` in a ``multiprocessing`` pool. The
    per-slice results are recombined along the last axis: stacked into a new
    last axis when ``self.output_dim > self.input_dim``, concatenated along
    the existing last axis otherwise.

    NOTE(review): in the parallel branch ``process_func`` itself is never
    passed to ``self.worker``; presumably the worker reaches the processing
    through ``args[0]`` (e.g. the instance a decorated method is bound to)
    -- confirm against the worker's definition. The wrapper accepts
    positional arguments only.

    Args:
        process_func: The callable to wrap; invoked as ``process_func(*args)``
            on the direct path.

    Returns:
        The wrapping function.
    """
    import functools  # local import: the file's import block is not visible here

    @functools.wraps(process_func)
    def wrapper(*args):
        # Presumably a NumPy-style array (the code uses .shape and [..., i]).
        data_obj = args[1]
        if (len(data_obj.shape) <= self.input_dim
                or data_obj.shape[-1] == 1):
            return process_func(*args)

        # One argument tuple per slice along the last axis, with the slice
        # substituted for the data object.
        arglist = [
            (args[0], data_obj[..., i]) + args[2:]
            for i in range(data_obj.shape[-1])
        ]
        # Never start more processes than there are slices; keep at least one
        # so an empty last axis still ends in the same ValueError from
        # np.stack / np.concatenate below.
        # TODO: make the process count configurable.
        n_workers = max(1, min(mp.cpu_count(), len(arglist)))
        # The context manager terminates the pool's processes on exit, so no
        # worker processes are left behind after each call. pool.map() blocks
        # until every result has been returned, so terminating loses nothing.
        with mp.Pool(n_workers) as pool:
            result = pool.map(self.worker, arglist)

        if self.output_dim > self.input_dim:  # expanding: add a new last axis
            return np.stack(result, -1)
        return np.concatenate(result, -1)  # contracting: join along last axis

    return wrapper
# NOTE(review): stray web-page residue, not code:
# "评论列表" ("comment list"), "文章目录" ("table of contents").