def inference_decoding_layer(enc_output,
                             enc_output_lengths,
                             fst,
                             keep_prob):
    '''Inference decoding layer for the model.

    Args:
        enc_output: encoder outputs used by the decoder cell.
        enc_output_lengths: lengths of the encoder output sequences.
        fst: decoding FST passed to the decoder cell when
            FLAGS.use_inference_lm is enabled.
        keep_prob: dropout keep probability for the decoder cell.

    Returns:
        Beam search predictions from tf.contrib.seq2seq.dynamic_decode.
    '''
    # Build the decoder cell over the encoder outputs, optionally fusing
    # the inference language model (FST).
    dec_cell = get_dec_cell(
        enc_output,
        enc_output_lengths,
        FLAGS.use_inference_lm,
        fst,
        FLAGS.beam_width,
        keep_prob)

    # Beam search runs batch_size * beam_width hypotheses in parallel,
    # so the initial state is sized accordingly.
    initial_state = dec_cell.zero_state(
        dtype=tf.float32,
        batch_size=FLAGS.batch_size * FLAGS.beam_width)

    start_tokens = tf.fill(
        [FLAGS.batch_size],
        VOCAB_TO_INT['<s>'],
        name='start_tokens')

    # tf.eye(VOCAB_SIZE) acts as a one-hot "embedding" matrix, so the beam
    # search feeds one-hot token vectors back into the decoder.
    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        dec_cell,
        tf.eye(VOCAB_SIZE),
        start_tokens,
        VOCAB_TO_INT['</s>'],
        initial_state,
        FLAGS.beam_width)

    predictions, _, _ = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder,
        output_time_major=False,
        maximum_iterations=FLAGS.max_output_len)

    return predictions
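
# A minimal usage sketch, assuming the surrounding script already defines
# FLAGS and an encoder; the placeholder names, enc_dim, fst=None, and
# keep_prob=1.0 below are illustrative assumptions, not part of the
# original code.
enc_dim = 512  # hypothetical encoder output dimension
enc_output = tf.placeholder(
    tf.float32, [FLAGS.batch_size, None, enc_dim], name='enc_output')
enc_output_lengths = tf.placeholder(
    tf.int32, [FLAGS.batch_size], name='enc_output_lengths')

predictions = inference_decoding_layer(
    enc_output,
    enc_output_lengths,
    fst=None,        # no language-model FST in this sketch
    keep_prob=1.0)   # disable dropout at inference time

# dynamic_decode with a BeamSearchDecoder returns a
# FinalBeamSearchDecoderOutput; predicted_ids has shape
# [batch_size, max_output_len, beam_width], and beam 0 holds the
# highest-scoring hypothesis.
best_ids = predictions.predicted_ids[:, :, 0]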