def beam_search(self, initial_state):
    """Run beam search from ``initial_state`` and return the finished hypotheses.

    Beam search is treated as a graph search. At every step:

    - each state in the current beam is expanded with ``self.successor``;
    - all candidates are pooled, and only the ``self.beamsize`` cheapest
      (lowest ``"cost"``) are kept;
    - kept candidates whose path ends in ``self.token2index["<eos>"]``, or whose
      path length reaches ``self.depth_limit``, are moved to the finished list;
    - the remaining kept candidates form the next beam.

    Args:
        initial_state: the starting search state, indexed like a mapping. The
            code reads its ``"path"`` key (list of word indices, e.g. starting
            with ``self.token2index["<sos>"]``) and its ``"cost"`` key
            (accumulated negative log-likelihood). Any other fields, such as
            LSTM hidden/cell states, are presumably consumed by
            ``self.successor`` — confirm against its implementation.

    Returns:
        The finished states sorted by ascending ``"cost"``. One step can finish
        several hypotheses at once, so the list may hold more than
        ``self.beamsize`` entries. It holds fewer if the beam runs out of live
        states first.
    """
    eos_index = self.token2index["<eos>"]
    finished = []
    beam = [initial_state]
    # Stop once enough hypotheses have finished, or when no live state is left
    # to expand; without the `beam` check the loop would spin forever whenever
    # the successors run out before `beamsize` paths have finished.
    while beam and len(finished) < self.beamsize:
        # Expand every live state one step and pool all candidates.
        candidates = []
        for state in beam:
            # successor returns (hy, cy, next_states); only next_states is used here.
            _, _, next_states = self.successor(state)
            candidates.extend(next_states)
        # Keep only the beamsize cheapest candidates (lowest negative log-likelihood).
        selected = heapq.nsmallest(self.beamsize, candidates, key=lambda s: s["cost"])
        beam = []
        for state in selected:
            path = state["path"]
            # ">=" rather than "==" so a path already at/over the limit still terminates.
            if path[-1] == eos_index or len(path) >= self.depth_limit:
                finished.append(state)
            else:
                beam.append(state)
    return sorted(finished, key=lambda s: s["cost"])
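# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): navigation text
# left over from the web page this snippet was copied from; not part of the code.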