def __init__(self,
             name: str,
             parent_decoder: Decoder,
             beam_size: int,
             length_normalization: float,
             max_steps: int = None,
             save_checkpoint: str = None,
             load_checkpoint: str = None) -> None:
    """Set up a beam-search decoding model part and build its loop.

    Args:
        name: Model part name, forwarded to ``ModelPart.__init__``.
        parent_decoder: Decoder stored as ``self.parent_decoder``; its
            ``max_output_len`` is the fallback when ``max_steps`` is None.
        beam_size: Stored as ``self._beam_size`` (presumably the number
            of hypotheses kept per step; used inside ``_decoding_loop``).
        length_normalization: Stored as ``self._length_normalization``;
            how it is applied is defined in ``_decoding_loop``.
        max_steps: Maximum output length. ``None`` means "use
            ``parent_decoder.max_output_len``".
        save_checkpoint: Forwarded to ``ModelPart.__init__``.
        load_checkpoint: Forwarded to ``ModelPart.__init__``.

    Side effects:
        Creates a TensorFlow constant holding the loop's step limit and
        calls ``self._decoding_loop()`` during construction, storing its
        result in ``self.outputs``.
    """
    # Check the arguments before any local is rebound, so the check sees
    # them exactly as the caller passed them.
    check_argument_types()
    ModelPart.__init__(self, name, save_checkpoint, load_checkpoint)

    self.parent_decoder = parent_decoder
    self._beam_size = beam_size
    self._length_normalization = length_normalization

    step_limit = (parent_decoder.max_output_len if max_steps is None
                  else max_steps)
    # Outputs of length n are only collected in step n + 1, so the loop
    # runs one extra step; the decoder step executed there is discarded.
    self._max_steps = tf.constant(step_limit + 1)
    self.max_output_len = step_limit

    # Feedables: left as None here; presumably populated while the
    # decoding loop is built (not visible from this method).
    self._search_state = None  # type: SearchState
    self._decoder_state = None  # type: NamedTuple

    # Output: the decoding graph is built eagerly at construction time.
    self.outputs = self._decoding_loop()
评论列表
文章目录