def _decoding_loop(self) -> BeamSearchOutput:
    """Build the beam search decoding loop graph.

    Returns a ``BeamSearchOutput`` holding the stacked scores, parent ids
    and token ids from the last search step, together with the final
    decoder feedables and beam search state.
    """
    # NOTE(review): attention objects are not collected yet, see TODO below.
    body_fn = self.get_body()
    init_state = self.get_initial_loop_state()

    def continue_decoding(*flat_state) -> tf.Tensor:
        """Loop predicate: keep decoding while step - 1 < max_steps."""
        loop_state = BeamSearchLoopState(*flat_state)
        step = loop_state.decoder_loop_state.feedables.step
        return tf.less(step - 1, self._max_steps)

    # The first step must be unrolled manually: ``tf.while_loop`` requires
    # identical shapes between iterations, but the initial beam state is
    # not beam-sized — it is just a single state.
    #
    # With ensembles, the ensembled logprobs have to be provided to
    # ``body_fn`` before this manual first step is run.
    state_after_first = tf.cond(
        continue_decoding(*init_state),
        lambda: body_fn(*init_state),
        lambda: init_state)

    final_state = tf.while_loop(
        continue_decoding, body_fn, state_after_first)

    search_output = final_state.bs_output

    # TODO: return att_loop_states properly
    return BeamSearchOutput(
        last_search_step_output=SearchStepOutput(
            scores=search_output.scores.stack(),
            parent_ids=search_output.parent_ids.stack(),
            token_ids=search_output.token_ids.stack()),
        last_dec_loop_state=final_state.decoder_loop_state.feedables,
        last_search_state=final_state.bs_state,
        attention_loop_states=[])