def _decode_infer(self, decoder, bridge, _encoder_output, features, labels):
    """Run the decoder in inference mode using a greedy embedding helper.

    Args:
        decoder: Callable invoked as ``decoder(initial_state, helper)``;
            its result is returned unchanged.
        bridge: Zero-argument callable whose result is used as the
            decoder's initial state.
        _encoder_output: Not used by this method.
        features: Forwarded to ``self.batch_size``.
        labels: Forwarded to ``self.batch_size``.

    Returns:
        Whatever ``decoder`` returns for the constructed initial state
        and helper.
    """
    # The regular batch size is always computed first; with beam search
    # enabled it is then replaced by the configured beam width.
    num_sequences = self.batch_size(features, labels)
    if self.use_beam_search:
        num_sequences = self.params["inference.beam_search.beam_width"]

    special_vocab = self.target_vocab_info.special_vocab
    start_id = special_vocab.SEQUENCE_START

    # One SEQUENCE_START id per sequence; the helper is given
    # SEQUENCE_END as its end token.
    embedding = self.target_embedding
    start_tokens = tf.fill([num_sequences], start_id)
    greedy_helper = tf_decode_helper.GreedyEmbeddingHelper(
        embedding=embedding,
        start_tokens=start_tokens,
        end_token=special_vocab.SEQUENCE_END,
    )

    initial_state = bridge()
    return decoder(initial_state, greedy_helper)
评论列表
文章目录