def decoding_loop(self, train_mode: bool, sample: bool = False) -> Tuple[
        tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Build the decoder ``tf.while_loop`` and stack its decoding histories.

    The loop starts from ``self.get_initial_loop_state()``, continues while
    ``self.loop_continue_criterion`` holds, and advances each step with the
    body returned by ``self.get_body(train_mode, sample)``. After the loop is
    constructed, ``self.finalize_loop`` is called on the final loop state to
    further post-process it.

    Arguments:
        train_mode: Boolean flag, telling whether this is a training run.
        sample: Boolean flag, telling whether the output symbols should be
            sampled from the output distribution instead of being chosen by
            argmax or taken from gold data.

    Returns:
        A 4-tuple ``(logits, decoder_outputs, mask, decoded)``. The items are
        the results of calling ``.stack()`` on ``histories.logits``,
        ``histories.decoder_outputs``, ``histories.mask`` and
        ``histories.outputs`` of the final loop state, in that order. The
        leading axis is presumably the decoding step (one entry written per
        loop iteration); confirm against ``get_body``.
    """
    start_state = self.get_initial_loop_state()
    keep_going = self.loop_continue_criterion
    step_fn = self.get_body(train_mode, sample)
    end_state = tf.while_loop(keep_going, step_fn, start_state)

    # Any value returned by finalize_loop is not used here.
    self.finalize_loop(end_state, train_mode)

    histories = end_state.histories
    # The stack ops are added to the graph in this order:
    # logits, decoder_outputs, outputs, mask.
    stacked_logits = histories.logits.stack()
    stacked_decoder_outputs = histories.decoder_outputs.stack()
    stacked_outputs = histories.outputs.stack()
    # TODO mask should include also the end symbol
    stacked_mask = histories.mask.stack()

    return (stacked_logits,
            stacked_decoder_outputs,
            stacked_mask,
            stacked_outputs)
评论列表
文章目录