def process_decoding_input(target_data, target_vocab_to_int, batch_size):
l_word = tf.strided_slice(target_data, [0, 0], [batch_size, -1], [1, 1])
return tf.concat([tf.fill([batch_size, 1], target_vocab_to_int['<GO>']), l_word], 1)
process_inputs.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录