def select_next_words(self, next_costs, next_probs, step_num, how_many):
    """Pick the ``how_many`` lowest-cost (row, column) entries of ``next_costs``.

    Args:
        next_costs: 2-D numpy array of costs; lower is better. Presumably
            rows are beam hypotheses and columns are vocabulary entries
            (TODO confirm against the caller).
        next_probs: Unused here; kept for interface compatibility.
        step_num: Decoding step. At step 0 only the first row is searched.
        how_many: Number of entries to select. If it is at least the number
            of candidates, every candidate is returned instead of raising.

    Returns:
        A pair ``((row_indices, col_indices), costs)``. The index arrays
        locate the selected entries in ``next_costs``, and ``costs`` holds
        their values. Both are ordered by ascending cost.
    """
    if step_num == 0:
        # Pick only from the first row at the beginning of sampling. This
        # avoids a duplicate <q> token being selected from several rows
        # (presumably all rows are identical at step 0 — confirm with caller).
        flat_next_costs = next_costs[:1, :].flatten()
    else:
        # Search every (row, column) pair. NOTE: this function does not itself
        # mask finished hypotheses (e.g. by setting their cost to infinity);
        # any such masking must already be reflected in next_costs.
        flat_next_costs = next_costs.flatten()

    num_candidates = flat_next_costs.size
    if how_many < num_candidates:
        # argpartition moves the how_many smallest costs (unordered) to the
        # front; it requires kth < size, hence this guard.
        args = numpy.argpartition(flat_next_costs, how_many)[:how_many]
    else:
        # At least as many entries requested as exist: take them all
        # (argpartition would raise ValueError for kth >= size).
        args = numpy.arange(num_candidates)
    # Order the selected candidates by ascending cost.
    args = args[numpy.argsort(flat_next_costs[args])]

    # Flat indices are < next_costs.shape[1] at step 0, so they unravel to row 0.
    return numpy.unravel_index(args, next_costs.shape), flat_next_costs[args]
评论列表
文章目录