import heapq
import operator


def search(lattice, queues, rnn_predictor, ngrams, beam_size, viterbi_size):
    # Breadth-first search with beam pruning and Viterbi-like pruning.
    for i in range(len(lattice)):
        queue = []
        # Create hypotheses without predicting the next word yet.
        for j in range(len(lattice[i])):
            for target, source, word_id in lattice[i][j]:
                word_queue = []
                for previous_cost, previous_history, previous_state, previous_prediction in queues[j]:
                    history = previous_history + [(target, source)]
                    cost = previous_cost + interpolate(previous_prediction[word_id],
                                                       get_ngram_cost(ngrams, history))
                    # A temporary hypothesis is a tuple (cost, history, word_id, previous_state);
                    # lazy prediction later replaces word_id and previous_state
                    # with the new RNN state and prediction.
                    hypothesis = (cost, history, word_id, previous_state)
                    word_queue.append(hypothesis)
                # Viterbi-like pruning: keep only the best hypotheses per word.
                if viterbi_size > 0:
                    word_queue = heapq.nsmallest(viterbi_size, word_queue,
                                                 key=operator.itemgetter(0))
                queue += word_queue
        # Beam pruning: keep only the best hypotheses at this position.
        if beam_size > 0:
            queue = heapq.nsmallest(beam_size, queue, key=operator.itemgetter(0))
        # Predict the next word distribution and RNN state before continuing.
        for cost, history, word_id, previous_state in queue:
            predictions, states = rnn_predictor.predict([word_id], [previous_state])
            hypothesis = (cost, history, states[0], predictions[0])
            queues[i].append(hypothesis)
    return queues
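
The two cost helpers, interpolate and get_ngram_cost, are defined elsewhere. As a rough illustration only, here is a minimal sketch of what they might look like, assuming both models return negative log probabilities, a fixed interpolation weight, and a dictionary-style n-gram table keyed by tuples of the (target, source) pairs that make up the history; the weight LAMBDA and the penalty-free backoff are assumptions, not the original implementation.

import math

LAMBDA = 0.5  # assumed interpolation weight between the RNN and n-gram models


def interpolate(rnn_cost, ngram_cost, weight=LAMBDA):
    # Combine the two models by linear interpolation in probability space;
    # both arguments are assumed to be negative log probabilities.
    return -math.log(weight * math.exp(-rnn_cost)
                     + (1.0 - weight) * math.exp(-ngram_cost))


def get_ngram_cost(ngrams, history, max_order=3):
    # Look up the longest matching suffix of the history, backing off to
    # shorter contexts; a real backoff model would also add backoff penalties.
    for order in range(min(max_order, len(history)), 0, -1):
        key = tuple(history[-order:])
        if key in ngrams:
            return ngrams[key]
    return float('inf')  # the hypothesis ends in an unseen unigram

Note that search defers the RNN call until after both pruning passes, so rnn_predictor.predict runs on at most beam_size hypotheses per lattice position; this is the lazy prediction mentioned in the comments.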