def jobmap(func, INPUT_ITR, FLAG_PARALLEL=False, batch_size=None,
           *args, **kwargs):
    """Lazily map ``func`` over ``INPUT_ITR`` with joblib, yielding results.

    Each item ``x`` is evaluated as ``func(x, *args, **kwargs)`` through a
    single ``joblib.Parallel`` pool that is reused for every dispatch.

    Args:
        func: Callable applied to every input item.
        INPUT_ITR: Iterable of items to process.
        FLAG_PARALLEL: If true, run with ``n_jobs=-1`` (joblib: all CPUs);
            otherwise ``n_jobs=1`` (sequential).
        batch_size: If ``None``, the whole of ``INPUT_ITR`` is handed to one
            ``Parallel`` call and its results are yielded after that call
            returns. Otherwise the input is split with ``grouper`` into
            batches, each batch is run in its own ``Parallel`` call, and a
            ``tqdm`` progress bar is advanced by the number of results each
            batch produced.
        *args, **kwargs: Extra arguments forwarded to every ``func`` call.
            Because ``*args`` follows ``batch_size`` in the signature, extra
            positional arguments can only be supplied when ``FLAG_PARALLEL``
            and ``batch_size`` are also passed positionally.

    Yields:
        The return value of ``func`` for each input item.
    """
    n_jobs = -1 if FLAG_PARALLEL else 1
    dfunc = joblib.delayed(func)

    with joblib.Parallel(n_jobs=n_jobs) as MP:
        # Without a batch_size, hand the whole input to joblib in one call.
        if batch_size is None:
            for z in MP(dfunc(x, *args, **kwargs) for x in INPUT_ITR):
                yield z
            # A bare ``return`` ends the generator cleanly; ``raise
            # StopIteration`` here is turned into RuntimeError by PEP 479.
            return

        # NOTE(review): assumes ``grouper`` yields consecutive chunks of at
        # most ``batch_size`` input items -- confirm. An itertools-recipe
        # style grouper would pad the last chunk with a fill value that is
        # then passed to ``func``.
        # The context manager closes the bar even if the consumer stops early.
        with tqdm() as progress_bar:
            for block in grouper(iter(INPUT_ITR), batch_size):
                n_done = 0
                for z in MP(dfunc(x, *args, **kwargs) for x in block):
                    yield z
                    n_done += 1
                # Count explicitly rather than reusing an enumerate index,
                # which is unbound/stale when a block yields no results.
                progress_bar.update(n_done)
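# 评论列表 ("Comment list") -- web-page residue from the source this was copied from
# 文章目录 ("Table of contents") -- web-page residue from the source this was copied from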