seq2seq_tf.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:PyTorchDemystified 作者: hhsecond 项目源码 文件源码
def process_decoder_input(target_data, target_vocab_to_int, batch_size):
    # Take off the last column
    sliced = tf.strided_slice(target_data, [0, 0], [batch_size, -1], [1, 1])
    # Append a column filled with <GO>
    decoder_input = tf.concat([tf.fill([batch_size, 1], target_vocab_to_int['<GO>']), sliced], 1)
    return decoder_input
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号