seq2seq_model.py source code


Project: deepsphinx | Author: vagrawal | Language: Python

This snippet builds the training-time (teacher-forced) decoding layer of the deepsphinx seq2seq speech recognition model.
# TensorFlow 1.x; tf.contrib.seq2seq was removed in TensorFlow 2.x.
import tensorflow as tf

# FLAGS, VOCAB_TO_INT, VOCAB_SIZE and get_dec_cell are module-level
# names defined elsewhere in deepsphinx.

def training_decoding_layer(
        target_data,
        target_lengths,
        enc_output,
        enc_output_lengths,
        fst,
        keep_prob):
    '''Training decoding layer for the model.

    Returns:
        Training logits
    '''
    # Shift the targets right: prepend the start token <s> and drop the
    # last token, so the decoder predicts each token from its
    # predecessors.
    target_data = tf.concat(
        [tf.fill([FLAGS.batch_size, 1], VOCAB_TO_INT['<s>']),
         target_data[:, :-1]], 1)

    # Attention decoder cell; FLAGS.use_train_lm and fst control the
    # optional training-time language model, keep_prob the dropout.
    dec_cell = get_dec_cell(
        enc_output,
        enc_output_lengths,
        FLAGS.use_train_lm,
        fst,
        1,
        keep_prob)

    initial_state = dec_cell.zero_state(
        dtype=tf.float32,
        batch_size=FLAGS.batch_size)

    # One-hot encode the target ids by looking rows up in an identity matrix.
    target_data = tf.nn.embedding_lookup(
        tf.eye(VOCAB_SIZE),
        target_data)

    # Teacher forcing: feed the ground-truth token at every decoding step.
    training_helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=target_data,
        sequence_length=target_lengths,
        time_major=False)

    training_decoder = tf.contrib.seq2seq.BasicDecoder(
        dec_cell,
        training_helper,
        initial_state)

    training_logits, _, _ = tf.contrib.seq2seq.dynamic_decode(
        training_decoder,
        output_time_major=False,
        impute_finished=True)

    return training_logits
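
For context, here is a minimal sketch of how this layer might be wired into a training graph and loss. The placeholder shapes, the random stand-in encoder output, fst=None, and the keep_prob value are illustrative assumptions, not deepsphinx's actual training script; it also presumes FLAGS.batch_size equals batch_size below and that FLAGS.use_train_lm is off so fst can be None.

import tensorflow as tf

batch_size = 4          # assumed to match FLAGS.batch_size
max_enc_time = 50
enc_hidden = 256

targets = tf.placeholder(tf.int32, [batch_size, None], name='targets')
target_lengths = tf.placeholder(tf.int32, [batch_size], name='target_lengths')

# Stand-in encoder output; in deepsphinx this comes from the encoder stack.
enc_output = tf.random_normal([batch_size, max_enc_time, enc_hidden])
enc_output_lengths = tf.fill([batch_size], max_enc_time)

training_logits = training_decoding_layer(
    targets, target_lengths, enc_output, enc_output_lengths,
    fst=None, keep_prob=0.9)

# dynamic_decode returns a BasicDecoderOutput; .rnn_output holds the
# per-step logits of shape [batch, time, vocab]. Targets must be padded
# to max(target_lengths) so the time dimensions line up.
loss = tf.contrib.seq2seq.sequence_loss(
    training_logits.rnn_output,
    targets,
    tf.sequence_mask(target_lengths, dtype=tf.float32))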