def dump_source_translation(model, source_buckets, vocab_inv_source, vocab_inv_target, beam_width=8, normalization_alpha=0):
    """Translate every bucket of source sentences and print the results.

    When ``beam_width == 1`` a fast batched greedy decode is used and a single
    translation per sentence is dumped; otherwise each sentence is decoded
    individually with beam search and *all* candidate translations are dumped.

    Args:
        model: translation model, passed through to the translate helpers.
        source_buckets: iterable of 2-D arrays of source token ids
            (presumably one bucket per length range — shape
            num_sentences x max_length; confirm against caller).
        vocab_inv_source: id -> token mapping for the source vocabulary.
        vocab_inv_target: id -> token mapping for the target vocabulary.
        beam_width: 1 selects greedy decoding; >1 selects beam search.
        normalization_alpha: length-normalization exponent for beam search
            score normalization (ignored by the greedy path).
    """
    for source_bucket in source_buckets:
        if beam_width == 1:  # greedy decoding, batched for throughput
            batchsize = 24
            # Interior split points at every multiple of batchsize; np.split
            # then yields ceil(len / batchsize) sections of at most batchsize
            # rows (a single section when the bucket already fits in a batch,
            # since the range is empty in that case).
            split_points = list(range(batchsize, len(source_bucket), batchsize))
            source_sections = np.split(source_bucket, split_points, axis=0)
            for source_batch in source_sections:
                # Cap the output length at twice the source length.
                translation_batch = translate_greedy(model, source_batch, source_batch.shape[1] * 2, len(vocab_inv_target), beam_width)
                for source, translation in zip(source_batch, translation_batch):
                    dump_translation(vocab_inv_source, vocab_inv_target, source, translation)
        else:  # beam search, one sentence at a time
            for source in source_bucket:
                translations = translate_beam_search(model, source, source.size * 2, len(vocab_inv_target), beam_width, normalization_alpha, return_all_candidates=True)
                dump_all_translation(vocab_inv_source, vocab_inv_target, source, translations)
# (scraped-page residue, not code: "评论列表" = "Comment list", "文章目录" = "Table of contents")