def decoding_layer_infer(encoder_state, dec_cell, dec_embeddings, start_of_sequence_id,
                         end_of_sequence_id, max_target_sequence_length,
                         vocab_size, output_layer, batch_size, keep_prob):
    """Build the greedy (inference-time) decoder of a seq2seq model.

    Each step feeds the embedding of the previously predicted token back into
    ``dec_cell``. Decoding stops once every sequence has emitted
    ``end_of_sequence_id``, or after ``max_target_sequence_length`` steps.

    :param encoder_state: Initial state for the decoder cell (the encoder's
        final state).
    :param dec_cell: RNN cell used for decoding.
    :param dec_embeddings: Decoder embedding matrix (or embedding callable, as
        accepted by ``GreedyEmbeddingHelper``).
    :param start_of_sequence_id: Token id fed as the first decoder input.
    :param end_of_sequence_id: Token id that marks a sequence as finished.
    :param max_target_sequence_length: Upper bound on decoding steps.
        Presumably a scalar int or scalar tensor; confirm against the caller.
    :param vocab_size: Unused here; kept so the signature stays compatible
        with existing callers.
    :param output_layer: Projection applied to the cell outputs at each step.
    :param batch_size: Number of sequences decoded in parallel.
    :param keep_prob: Unused here (no dropout is applied at inference); kept
        so the signature stays compatible with existing callers.
    :return: The decoder's final outputs. For ``BasicDecoder`` this is a
        ``BasicDecoderOutput`` with ``rnn_output`` and ``sample_id`` fields.
    """
    # One start token per batch element: tile the 1-element vector to shape [batch_size].
    start_tokens = tf.tile(
        tf.constant([start_of_sequence_id], dtype=tf.int32),
        [batch_size],
        name='start_tokens')

    # Greedy helper: at each step, embed the argmax token and feed it back in.
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=dec_embeddings,
        start_tokens=start_tokens,
        end_token=end_of_sequence_id)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=dec_cell,
        helper=helper,
        initial_state=encoder_state,
        output_layer=output_layer)

    # dynamic_decode returns (final_outputs, final_state) in TF < 1.3 and
    # (final_outputs, final_state, final_sequence_lengths) in TF >= 1.3.
    # Indexing [0] works on both; tuple-unpacking into two names breaks on >= 1.3.
    infer_decoder_output = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        impute_finished=True,
        maximum_iterations=max_target_sequence_length)[0]
    return infer_decoder_output
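# Leftover navigation text from the web page this snippet was copied from:
# "Comment list" and "Table of contents".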