def training_decoding_layer(
        target_data,
        target_lengths,
        enc_output,
        enc_output_lengths,
        fst,
        keep_prob):
    '''Build the decoder sub-graph used during training (teacher forcing).

    The target sequences are shifted right by one position and prefixed with
    the start-of-sentence token, then one-hot embedded and fed to a
    TrainingHelper-driven BasicDecoder over the attention cell produced by
    get_dec_cell.

    Args:
        target_data: int tensor of target token ids, batch-major.
        target_lengths: per-example target sequence lengths.
        enc_output: encoder outputs to attend over.
        enc_output_lengths: lengths of the encoder outputs.
        fst: FST passed through to get_dec_cell (language-model fusion).
        keep_prob: dropout keep probability for the decoder cell.

    Returns:
        Training logits (the first element of dynamic_decode's output).
    '''
    # Shift targets right and prepend <s> so the decoder predicts token t
    # from tokens < t.
    start_column = tf.fill([FLAGS.batch_size, 1], VOCAB_TO_INT['<s>'])
    shifted_targets = tf.concat([start_column, target_data[:, :-1]], 1)

    # Attention decoder cell; beam width 1 since training is greedy/teacher-forced.
    cell = get_dec_cell(
        enc_output,
        enc_output_lengths,
        FLAGS.use_train_lm,
        fst,
        1,
        keep_prob)
    zero_state = cell.zero_state(
        batch_size=FLAGS.batch_size,
        dtype=tf.float32)

    # One-hot embedding: rows of the identity matrix serve as the embedding table.
    embedded_targets = tf.nn.embedding_lookup(
        tf.eye(VOCAB_SIZE),
        shifted_targets)

    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=embedded_targets,
        sequence_length=target_lengths,
        time_major=False)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell,
        helper,
        zero_state)
    logits, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        output_time_major=False,
        impute_finished=True)
    return logits
# NOTE(review): removed stray web-page scrape artifacts ("评论列表" = "comment
# list", "文章目录" = "article table of contents") — they were not code and
# made the file a syntax error.