translate.py source code

python

Project: chainer-qrnn  Author: musyoku
import numpy as np

# translate_greedy, translate_beam_search, dump_translation and dump_all_translation
# are helpers defined elsewhere in the chainer-qrnn project (not shown in this excerpt).
def dump_source_translation(model, source_buckets, vocab_inv_source, vocab_inv_target, beam_width=8, normalization_alpha=0):
    for source_bucket in source_buckets:
        if beam_width == 1:  # greedy decoding, processed in mini-batches
            batchsize = 24
            if len(source_bucket) > batchsize:
                # Split the bucket into chunks of at most `batchsize` rows;
                # `indices` holds the split points batchsize, 2 * batchsize, ...
                num_sections = len(source_bucket) // batchsize - 1
                if len(source_bucket) % batchsize > 0:
                    num_sections += 1
                indices = [(i + 1) * batchsize for i in range(num_sections)]
                source_sections = np.split(source_bucket, indices, axis=0)
            else:
                source_sections = [source_bucket]

            for source_batch in source_sections:
                # Length limit: twice the (padded) source length.
                translation_batch = translate_greedy(model, source_batch, source_batch.shape[1] * 2, len(vocab_inv_target), beam_width)
                for index in range(len(translation_batch)):
                    source = source_batch[index]
                    translation = translation_batch[index]
                    dump_translation(vocab_inv_source, vocab_inv_target, source, translation)
        else:  # beam search, one source sentence at a time
            for index in range(len(source_bucket)):
                source = source_bucket[index]
                # Request every beam candidate so that all of them are printed.
                translations = translate_beam_search(model, source, source.size * 2, len(vocab_inv_target), beam_width, normalization_alpha, return_all_candidates=True)
                dump_all_translation(vocab_inv_source, vocab_inv_target, source, translations)
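In the greedy branch, the `num_sections`/`indices` arithmetic yields the multiples of `batchsize` that lie strictly below `len(source_bucket)`, and `np.split` cuts the bucket at those rows. A minimal sketch of an equivalent, shorter formulation using `range`; `split_into_batches` is a hypothetical helper name (not part of chainer-qrnn), and the toy bucket stands in for real padded token-ID arrays:

```python
import numpy as np


def split_into_batches(bucket, batchsize=24):
    """Split `bucket` along axis 0 into chunks of at most `batchsize` rows.

    The split points are batchsize, 2 * batchsize, ... strictly below
    len(bucket), matching the indices computed in dump_source_translation.
    """
    indices = list(range(batchsize, len(bucket), batchsize))
    # np.split with an empty index list returns [bucket], so buckets that
    # already fit in one batch need no special case.
    return np.split(bucket, indices, axis=0)


# Toy bucket: 50 padded "sentences" of length 7 (placeholder token IDs).
bucket = np.arange(50 * 7).reshape(50, 7)
batches = split_into_batches(bucket)
print([b.shape[0] for b in batches])  # → [24, 24, 2]
```

Because `range` with a start at or above its stop is empty, the `if len(source_bucket) > batchsize` check in the original is subsumed by this form.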
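`normalization_alpha` is passed through to `translate_beam_search` with a default of 0, but this excerpt does not show how it is applied. One common choice in neural machine translation is the GNMT-style length penalty, which divides a candidate's summed log-probability by ((5 + |Y|) / 6)^α, so that α = 0 leaves scores unnormalized. A minimal sketch of re-ranking beam candidates under that assumption; `length_penalty`, `rerank` and the candidate list are hypothetical and not taken from chainer-qrnn:

```python
def length_penalty(length, alpha):
    """GNMT-style penalty ((5 + length) / 6) ** alpha; equals 1.0 when alpha == 0."""
    return ((5.0 + length) / 6.0) ** alpha


def rerank(candidates, alpha):
    """Sort (token_ids, log_prob) pairs by length-normalized score, best first."""
    return sorted(
        candidates,
        key=lambda c: c[1] / length_penalty(len(c[0]), alpha),
        reverse=True,
    )


# Hypothetical beam output: a short and a long candidate with summed log-probs.
candidates = [([4, 9, 2], -3.0), ([4, 9, 7, 5, 3, 8, 2], -4.2)]
print(rerank(candidates, alpha=0)[0][0])    # → [4, 9, 2]
print(rerank(candidates, alpha=1.0)[0][0])  # → [4, 9, 7, 5, 3, 8, 2]
```

With α = 1 the long candidate scores -4.2 / 2.0 = -2.1 against -3.0 / 1.33 = -2.25 for the short one, so a positive α offsets beam search's tendency to prefer short outputs.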