train.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:image_caption 作者: MaticsL 项目源码 文件源码
def gen_batch_in_thread(img_map, df_cap, vocab_size, n_jobs=4,
                        size_per_thread=32):
    imgs, curs, nxts, seqs, vhists = [], [], [], [], []
    returns = Parallel(n_jobs=4, backend='threading')(
                            delayed(generate_batch)
                            (img_train, df_cap, vocab_size, size=size_per_thread)
                            for i in range(0, n_jobs))

    for triple in returns:
        imgs.extend(triple[0])
        curs.extend(triple[1])
        nxts.extend(triple[2])
        seqs.extend(triple[3])
        vhists.extend(triple[4])

    return np.array(imgs), np.array(curs).reshape((-1, 1)), np.array(nxts), \
        np.array(seqs), np.array(vhists)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号