length.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:sgnmt 作者: ucam-smt 项目源码 文件源码
def predict_next(self):
        """Combine the stored per-position scores into one next-word score.

        Every index ``pos`` of ``self.scores`` is treated as a possible
        start position. For each start position the current
        ``self.history`` is matched word by word against the consecutive
        entries ``self.scores[pos]``, ``self.scores[pos + 1]``, ...; the
        looked-up history word scores are accumulated and added to the
        entry that follows the matched prefix. Results are bucketed by
        the number of matched history words (bucket 0 holds the
        unmodified entries).

        Buckets that are non-empty and whose index plus one is at least
        ``self.min_order`` are reduced over all start positions with
        ``logsumexp`` (``utils.log_sum`` for the UNK scores), and the
        reduced buckets are then summed.

        Side effect: sets ``self.cur_unk_score`` to the combined UNK
        score, or to ``0.0`` if no bucket qualified.

        Returns:
            The sum of the reduced score buckets, or an empty dict if no
            bucket qualified.
        """
        # NOTE(review): entries of self.scores are assumed to be
        # equally sized numpy arrays (np.vstack is applied to them
        # below) - confirm against the code that fills self.scores.
        n_positions = len(self.scores)
        n_buckets = len(self.history) + 1
        # ``range`` instead of the Python-2-only ``xrange`` so that the
        # method runs on both Python 2 and Python 3.
        this_scores = [[] for _ in range(n_buckets)]
        this_unk_scores = [[] for _ in range(n_buckets)]
        for pos in range(n_positions):
            this_scores[0].append(self.scores[pos])
            this_unk_scores[0].append(self.unk_scores[pos])
            acc = 0.0  # Accumulated score of the history words matched so far
            for order, word in enumerate(self.history):
                nxt = pos + order + 1
                if nxt >= n_positions:
                    # No entry left after the matched prefix.
                    break
                # presumably common_get falls back to the third argument
                # when ``word`` has no entry - verify in utils.
                acc += utils.common_get(
                    self.scores[pos + order], word,
                    self.unk_scores[pos + order])
                this_scores[order + 1].append(acc + self.scores[nxt])
                this_unk_scores[order + 1].append(acc + self.unk_scores[nxt])
        combined_scores = []
        combined_unk_scores = []
        for order, (scores, unk_scores) in enumerate(zip(this_scores,
                                                         this_unk_scores)):
            if scores and order + 1 >= self.min_order:
                # Reduce over start positions (axis 0 of the stacked rows).
                score_matrix = np.vstack(scores)
                combined_scores.append(logsumexp(score_matrix, axis=0))
                combined_unk_scores.append(utils.log_sum(unk_scores))
        if not combined_scores:
            self.cur_unk_score = 0.0
            return {}
        self.cur_unk_score = sum(combined_unk_scores)
        return sum(combined_scores)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号