seq2seq_model.py source code

python

Project: deepsphinx  Author: vagrawal
def inference_decoding_layer(
        enc_output,
        enc_output_lengths,
        fst,
        keep_prob):
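    ''' Inference decoding layer for the model.

    Runs beam search (width FLAGS.beam_width) over the decoder cell for at
    most FLAGS.max_output_len steps.

    Args:
        enc_output: encoder outputs, passed to get_dec_cell
        enc_output_lengths: lengths of the encoder outputs
        fst: passed to get_dec_cell together with FLAGS.use_inference_lm
        keep_prob: passed to get_dec_cell

    Returns:
        Predictions
    '''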

    dec_cell = get_dec_cell(
        enc_output,
        enc_output_lengths,
        FLAGS.use_inference_lm,
        fst,
        FLAGS.beam_width,
        keep_prob)

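    # Beam search tracks beam_width hypotheses per batch entry, so the
    # zero state is sized batch_size * beam_width.
    initial_state = dec_cell.zero_state(
        dtype=tf.float32,
        batch_size=FLAGS.batch_size * FLAGS.beam_width)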

    start_tokens = tf.fill(
        [FLAGS.batch_size],
        VOCAB_TO_INT['<s>'],
        name='start_tokens')

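    # The identity matrix acts as the embedding, so each token id is
    # looked up as a one-hot vector of size VOCAB_SIZE.
    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        dec_cell,
        tf.eye(VOCAB_SIZE),
        start_tokens,
        VOCAB_TO_INT['</s>'],
        initial_state,
        FLAGS.beam_width)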

    predictions, _, _ = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder,
        output_time_major=False,
        maximum_iterations=FLAGS.max_output_len)

    return predictions
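
The function above delegates the search itself to `tf.contrib.seq2seq.BeamSearchDecoder`. As a rough illustration of the idea only, here is a minimal pure-Python beam search sketch. It is not the TensorFlow implementation (no batching, no length penalty), and `beam_search`, `toy_step` and `TOY_MODEL` are hypothetical names invented for this example.

```python
import math


def beam_search(step_log_probs, start_token, end_token, beam_width, max_len):
    """Toy beam search (illustrative, not the tf.contrib implementation).

    step_log_probs(prefix) returns a dict: next token -> log-probability.
    Returns a list of (tokens, score) hypotheses, best first.
    """
    beams = [([start_token], 0.0)]
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == end_token:
                # Finished hypotheses are carried over unchanged.
                candidates.append((tokens, score))
                continue
            for tok, logp in step_log_probs(tokens).items():
                candidates.append((tokens + [tok], score + logp))
        # Keep only the beam_width highest-scoring hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
        if all(tokens[-1] == end_token for tokens, _ in beams):
            break
    return beams


# Hypothetical toy model: the next-token distribution depends only on
# the last token of the prefix.
TOY_MODEL = {
    '<s>': {'a': math.log(0.6), 'b': math.log(0.4)},
    'a': {'a': math.log(0.1), 'b': math.log(0.2), '</s>': math.log(0.7)},
    'b': {'a': math.log(0.05), 'b': math.log(0.05), '</s>': math.log(0.9)},
}


def toy_step(prefix):
    return TOY_MODEL[prefix[-1]]


results = beam_search(toy_step, '<s>', '</s>', beam_width=2, max_len=5)
best_tokens, best_score = results[0]
```

Here `beam_width` and `max_len` play the roles that `FLAGS.beam_width` and `FLAGS.max_output_len` play in the function above, and `'<s>'` / `'</s>'` mirror the start and end tokens taken from `VOCAB_TO_INT`.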