def advance(self, word_lk):
    """Advance the beam by one decoding step and report whether it is finished.

    Adds this step's per-word scores to each live hypothesis' running score,
    keeps the ``self.size`` best (hypothesis, word) pairs, and records the
    back-pointer and chosen word for each of them.

    Side effects: appends the previous ``self.scores`` to ``self.all_scores``,
    replaces ``self.scores`` with the new best scores, appends to
    ``self.prev_ks`` and ``self.next_ys``, and may set ``self.done``.

    Args:
        word_lk: 2-D tensor where ``word_lk[i, w]`` is the score for extending
            hypothesis ``i`` with word ``w`` (``size(1)`` is the vocabulary
            size). Presumably log-probabilities, since they are summed with
            the running scores -- TODO confirm against the caller.

    Returns:
        ``self.done``; it is set to True once the highest-ranked candidate
        of this step selects ``Constants.EOS``.
    """
    num_words = word_lk.size(1)

    # Sum the previous scores.
    if len(self.prev_ks) > 0:
        # Broadcast each hypothesis' running score across its row of words.
        beam_lk = word_lk + self.scores.unsqueeze(1).expand_as(word_lk)
    else:
        # First step: no history yet, so only row 0 is scored. Presumably all
        # rows are identical at this point and scoring every row would only
        # produce duplicate candidates -- NOTE(review): confirm with caller.
        beam_lk = word_lk[0]

    flat_beam_lk = beam_lk.view(-1)

    # One sorted, descending top-k over the flattened (hypothesis x word)
    # scores; index 0 is the best candidate.
    best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)

    self.all_scores.append(self.scores)
    self.scores = best_scores

    # best_scores_id indexes the flattened hypothesis x word array, so recover
    # which hypothesis (row) and which word (column) each score came from.
    # Floor division is required: `/` on integer tensors performs true
    # division on recent PyTorch and would yield fractional back-pointers
    # (and all-zero word ids below).
    prev_k = best_scores_id // num_words
    self.prev_ks.append(prev_k)
    self.next_ys.append(best_scores_id - prev_k * num_words)

    # End condition is when top-of-beam is EOS.
    if self.next_ys[-1][0] == Constants.EOS:
        self.done = True
        # Also record the final step's scores, which would otherwise never be
        # appended to all_scores.
        self.all_scores.append(self.scores)

    return self.done
Beam.py 文件源码
python
阅读 47
收藏 0
点赞 0
评论 0
评论列表
文章目录