seq2seq_tf.py — source file


Project: PyTorchDemystified   Author: hhsecond
# TensorFlow 1.x code: tf.contrib was removed in TensorFlow 2.x
import tensorflow as tf

def decoding_layer_infer(encoder_state, dec_cell, dec_embeddings, start_of_sequence_id,
                         end_of_sequence_id, max_target_sequence_length,
                         vocab_size, output_layer, batch_size, keep_prob):
    """Build the inference decoding layer: greedy decoding from encoder_state."""
    # One start-of-sequence token per batch element
    start_tokens = tf.tile(
        tf.constant([start_of_sequence_id], dtype=tf.int32), [batch_size], name='start_tokens')
    # Greedy helper: feeds each step's argmax prediction back in as the next input,
    # stopping when end_of_sequence_id is emitted
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        dec_embeddings,
        start_tokens,
        end_of_sequence_id)
    # Define the decoder
    decoder = tf.contrib.seq2seq.BasicDecoder(
        dec_cell,
        helper,
        encoder_state,
        output_layer)
    # Run the decoder
    infer_decoder_output, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        impute_finished=True,
        maximum_iterations=max_target_sequence_length)
    return infer_decoder_output
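The core of `GreedyEmbeddingHelper` is a simple loop: at each step, take the argmax over the output logits, feed that token back in as the next input, and stop at the end-of-sequence id or after `maximum_iterations` steps. A minimal pure-Python sketch of that loop (the `step` callable is hypothetical, standing in for the RNN cell plus output layer; this is an illustration, not the TensorFlow implementation):

```python
def greedy_decode(step, start_token, end_token, max_len):
    """Greedy decoding: repeatedly feed the argmax prediction back as input.

    `step(token)` is a hypothetical callable returning a list of logits
    over the vocabulary for the next token.
    """
    tokens = []
    token = start_token
    for _ in range(max_len):
        logits = step(token)
        token = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        if token == end_token:  # analogous to end_of_sequence_id
            break
        tokens.append(token)
    return tokens
```

The `max_len` cap plays the same role as `maximum_iterations` in `dynamic_decode`: it guarantees termination even if the model never emits the end token.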