def get_initial_loop_state(self) -> BeamSearchLoopState:
    """Build the loop state that the beam search starts from.

    Creates empty, dynamically sized output arrays, applies the parent
    decoder's body once to the parent decoder's initial loop state, and
    stores the resulting search state and decoder feedables on the
    instance as ``_search_state`` and ``_decoder_state``.

    Returns:
        A ``BeamSearchLoopState`` bundling the search state, the empty
        output arrays, and the decoder loop state after that one step.
    """

    def _empty_array(dtype, name):
        # Zero-length TensorArray created with dynamic_size=True.
        return tf.TensorArray(
            dtype=dtype, dynamic_size=True, size=0, name=name)

    # TODO make these feedable
    score_array = _empty_array(tf.float32, "beam_scores")
    parent_array = _empty_array(tf.int32, "beam_parents")
    token_array = _empty_array(tf.int32, "beam_tokens")
    step_outputs = SearchStepOutputTA(
        scores=score_array,
        parent_ids=parent_array,
        token_ids=token_array)

    # The decoder body is applied once up front so that its logits are
    # available for ensembling.
    initial_decoder_ls = self.parent_decoder.get_initial_loop_state()
    step_fn = self.parent_decoder.get_body(False)
    stepped_decoder_ls = step_fn(*initial_decoder_ls)

    # These values are placeholders (with defaults) so that they can be
    # fed from outside when ensembling.
    logprob_sum = tf.placeholder_with_default([0.0], [None])
    prev_logprobs = tf.nn.log_softmax(
        stepped_decoder_ls.feedables.prev_logits)
    lengths = tf.placeholder_with_default([1], [None])
    finished = tf.placeholder_with_default([False], [None])

    self._search_state = SearchState(
        logprob_sum=logprob_sum,
        prev_logprobs=prev_logprobs,
        lengths=lengths,
        finished=finished)
    self._decoder_state = stepped_decoder_ls.feedables

    # TODO make TensorArrays also feedable
    return BeamSearchLoopState(
        bs_state=self._search_state,
        bs_output=step_outputs,
        decoder_loop_state=stepped_decoder_ls)
# [web page residue] Comment list
# [web page residue] Article table of contents