def _translate(seq, f_init, f_next, trg_eos_idx, src_sel, trg_sel,
               k, cond_init_trg, normalize, n_best, **kwargs):
    """Decode one source sequence with ``gen_sample`` and keep the best hypotheses.

    ``gen_sample`` is called with ``stochastic=False`` and ``argmax=False``,
    beam size ``k``, and a maximum output length of three times the source
    length. The hypotheses with the LOWEST scores are then selected.

    Args:
        seq: Source token indices. Converted to a ``(len(seq), 1)`` numpy
            column and passed to ``gen_sample`` as ``x``.
        f_init, f_next: Forwarded unchanged to ``gen_sample``.
        trg_eos_idx: Forwarded to ``gen_sample`` as ``eos_idx``.
        src_sel, trg_sel: Forwarded as ``src_selector`` / ``trg_selector``.
        k: Forwarded to ``gen_sample`` as ``k``.
        cond_init_trg: Forwarded to ``gen_sample``.
        normalize: If true, divide each score by the length of its hypothesis.
        n_best: Number of hypotheses to return; must be >= 1.
        **kwargs: Extra keyword arguments forwarded to ``gen_sample``.

    Returns:
        If ``n_best == 1``: ``(hypothesis, score)`` for the lowest score.
        If ``n_best > 1``: up to ``n_best`` hypotheses and their scores,
        ordered by ascending score. The hypotheses are fancy-indexed when
        ``gen_sample`` returned a numpy array, otherwise they are returned
        as a list.

    Raises:
        ValueError: If ``n_best`` is less than 1.
    """
    # Reject bad input before paying for the beam search.
    if n_best < 1:
        raise ValueError('n_best must be a positive integer!')

    sample, score = gen_sample(
        f_init, f_next, x=numpy.array(seq).reshape([len(seq), 1]),
        eos_idx=trg_eos_idx, src_selector=src_sel, trg_selector=trg_sel,
        k=k, maxlen=3 * len(seq), stochastic=False, argmax=False,
        cond_init_trg=cond_init_trg, **kwargs)

    # Coerce scores to an array so element-wise division and index-array
    # selection work whether gen_sample returned a list or an ndarray.
    score = numpy.asarray(score)
    if normalize:
        lengths = numpy.array([len(s) for s in sample])
        score = score / lengths

    if n_best == 1:
        sidx = numpy.argmin(score)
        # A numpy integer scalar is a valid index for lists and arrays alike.
        return sample[sidx], score[sidx]

    sidx = numpy.argsort(score)[:n_best]
    if isinstance(sample, numpy.ndarray):
        best = sample[sidx]
    else:
        # Plain Python sequences cannot be indexed by an index array.
        best = [sample[i] for i in sidx]
    return best, score[sidx]
评论列表
文章目录