def gen_batch_in_thread(img_map, df_cap, vocab_size, n_jobs=4,
                        size_per_thread=32):
    """Build one batch by running ``generate_batch`` concurrently in threads.

    ``generate_batch`` is invoked ``n_jobs`` times through joblib's
    ``'threading'`` backend, each call asked for ``size_per_thread`` samples.
    The five sequences returned by each call are concatenated position-wise
    into one combined batch.

    Args:
        img_map: Image data forwarded to ``generate_batch``.
        df_cap: Caption data forwarded to ``generate_batch``.
        vocab_size: Vocabulary size forwarded to ``generate_batch``.
        n_jobs: Number of ``generate_batch`` calls to make, and the number of
            worker threads used to run them.
        size_per_thread: Value passed as ``size=`` to each
            ``generate_batch`` call.

    Returns:
        A tuple of five numpy arrays ``(imgs, curs, nxts, seqs, vhists)``
        built from the concatenated results. The second array is reshaped to
        a single column, shape ``(-1, 1)``.
    """
    # Previously the thread count was hard-coded to 4 and the global
    # `img_train` was passed instead of the `img_map` argument; both now
    # honour the caller's parameters.
    results = Parallel(n_jobs=n_jobs, backend='threading')(
        delayed(generate_batch)(img_map, df_cap, vocab_size,
                                size=size_per_thread)
        for _ in range(n_jobs))

    # Each result is expected to be a 5-element sequence of per-sample data.
    # Presumably it holds (images, current tokens, next tokens, sequences,
    # visual histories) -- TODO confirm against generate_batch.
    imgs, curs, nxts, seqs, vhists = [], [], [], [], []
    for part in results:
        imgs.extend(part[0])
        curs.extend(part[1])
        nxts.extend(part[2])
        seqs.extend(part[3])
        vhists.extend(part[4])

    return (np.array(imgs), np.array(curs).reshape((-1, 1)), np.array(nxts),
            np.array(seqs), np.array(vhists))
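# Comment list (stray web-page label left over from the scraped source)
# Table of contents (stray web-page label left over from the scraped source)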