beam.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:nmt 作者: Playinf 项目源码 文件源码
def find_nbest(score, n, threshold=None):
    num_vars = score.shape[1]

    score = score.flatten()
    nbest = np.argpartition(score, n)[:n]

    beam_indices = nbest / num_vars
    var_indices = nbest % num_vars
    nbest_score = score[nbest]

    if threshold:
        best = np.max(nbest_score)
        cond = nbest_score > best + threshold
        nbest_score = nbest_score[cond]
        beam_indices = beam_indices[cond]
        var_indices = var_indices[cond]

    return nbest_score, beam_indices, var_indices
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号