def process_decoder_input(target_data, target_vocab_to_int, batch_size):
    """Shift target sequences right by one step, starting each row with ``<GO>``.

    Keeps the first ``batch_size`` rows of ``target_data`` minus their final
    column, then prepends a column filled with the id of the ``'<GO>'`` token.
    The result has ``batch_size`` rows and the same number of columns as
    ``target_data``.

    Args:
        target_data: 2-D tensor of token ids. Presumably shaped
            (batch, time); TODO confirm against caller.
        target_vocab_to_int: Mapping from token string to integer id. It must
            contain ``'<GO>'``; otherwise ``KeyError`` is raised.
        batch_size: Number of rows to keep and to fill with ``<GO>``.

    Returns:
        Tensor produced by ``tf.concat`` of the ``<GO>`` column and the
        truncated targets along axis 1.

    Notes:
        ``target_data`` needs at least ``batch_size`` rows. ``strided_slice``
        clamps the row range, so fewer rows would make the concat shapes
        disagree. ``tf.fill`` takes its dtype from the ``<GO>`` id (int32 for a
        plain Python int). NOTE(review): ``target_data`` presumably has to match
        that dtype for the concat to succeed; confirm.
    """
    # A negative end index counts from the end, so -1 on the column axis
    # excludes the last column.
    trimmed = tf.strided_slice(
        target_data,
        begin=[0, 0],
        end=[batch_size, -1],
        strides=[1, 1],
    )
    go_token_id = target_vocab_to_int['<GO>']
    go_column = tf.fill([batch_size, 1], go_token_id)
    return tf.concat([go_column, trimmed], axis=1)
评论列表
文章目录