rnn.py source code

python

Project: reading-comprehension  Author: kellywzhang
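import tensorflow as tf


def bidirectional_rnn(forward_cell, backward_cell, inputs, seq_lens_mask, concatenate=True):
    """Run forward_cell over inputs and backward_cell over the reversed inputs.

    `rnn` is a helper defined elsewhere in rnn.py (not shown here); it returns
    (outputs, last_state) for a cell, an input batch and a 0/1 sequence mask.
    """
    # seq_lens_mask is batch x time, 1 for real tokens and 0 for padding
    seq_lens = tf.cast(tf.reduce_sum(seq_lens_mask, 1), tf.int32)

    # Reverse inputs (batch x time x embedding_dim); only the first seq_lens[b]
    # steps of each example are reversed, so padding stays at the end
    reverse_inputs = tf.reverse_sequence(inputs, seq_lens, seq_axis=1, batch_axis=0)

    # Run forward and backward RNNs; the same mask is valid for the reversed
    # inputs because the padding positions are unchanged
    forward_outputs, forward_last_state = \
        rnn(forward_cell, inputs, seq_lens_mask)
    backward_outputs_reversed, backward_last_state = \
        rnn(backward_cell, reverse_inputs, seq_lens_mask)

    # Undo the reversal so backward outputs line up in time with forward outputs
    backward_outputs = tf.reverse_sequence(backward_outputs_reversed, seq_lens, seq_axis=1, batch_axis=0)

    if concatenate:
        # last_state dimensions (per direction): batch x hidden_size
        last_state = tf.concat([forward_last_state, backward_last_state], 1)
        # outputs dimensions (per direction): batch x time x hidden_size
        outputs = tf.concat([forward_outputs, backward_outputs], 2)

        # Dimensions: outputs (batch x time x hidden_size*2); last_state (batch x hidden_size*2)
        return (outputs, last_state)

    # Dimensions (each direction): outputs (batch x time x hidden_size); last_state (batch x hidden_size)
    return (forward_outputs, forward_last_state, backward_outputs, backward_last_state)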
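The central trick in the function above is that reversing only the first seq_lens[b] steps of each example keeps padding at the end, so the backward pass can reuse the same mask, and a second reversal realigns its outputs with the forward pass. The following is a minimal pure-Python sketch of that pattern, assuming a hypothetical toy recurrence (a running sum over valid steps) in place of real RNN cells. The names `reverse_sequence`, `toy_rnn` and `bidirectional` are illustrative only and are not part of TensorFlow or the original project.

```python
from typing import List, Sequence

Vector = List[float]


def reverse_sequence(batch: Sequence[Sequence[Vector]], lengths: Sequence[int]) -> List[List[Vector]]:
    """Reverse the first lengths[b] steps of each example; padding stays in place."""
    out = []
    for seq, n in zip(batch, lengths):
        seq = list(seq)
        out.append(seq[:n][::-1] + seq[n:])
    return out


def toy_rnn(batch, mask):
    """Hypothetical stand-in for the `rnn` helper: h_t = h_{t-1} + x_t on valid steps.

    Returns (outputs, last_states). Padded steps emit zeros and leave the state unchanged.
    """
    outputs, last_states = [], []
    for seq, m in zip(batch, mask):
        h = [0.0] * len(seq[0])
        seq_out = []
        for x, valid in zip(seq, m):
            if valid:
                h = [a + b for a, b in zip(h, x)]
                seq_out.append(list(h))
            else:
                seq_out.append([0.0] * len(h))
        outputs.append(seq_out)
        last_states.append(h)
    return outputs, last_states


def bidirectional(batch, mask, forward_rnn=toy_rnn, backward_rnn=toy_rnn):
    """Same structure as bidirectional_rnn with concatenate=True."""
    lengths = [int(sum(m)) for m in mask]
    fw_out, fw_last = forward_rnn(batch, mask)
    # The mask still fits the reversed batch because padding did not move
    bw_out_rev, bw_last = backward_rnn(reverse_sequence(batch, lengths), mask)
    # Reverse back so that step t of both directions refers to the same token
    bw_out = reverse_sequence(bw_out_rev, lengths)
    # List concatenation plays the role of tf.concat on the feature axis
    outputs = [[f + b for f, b in zip(fs, bs)] for fs, bs in zip(fw_out, bw_out)]
    last = [f + b for f, b in zip(fw_last, bw_last)]
    return outputs, last


if __name__ == "__main__":
    batch = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [0.0]]]  # second example is padded
    mask = [[1, 1, 1], [1, 1, 0]]
    outputs, last = bidirectional(batch, mask)
    print(outputs[1])  # [[4.0, 9.0], [9.0, 5.0], [0.0, 0.0]]
    print(last)        # [[6.0, 6.0], [9.0, 9.0]]
```

For the padded example, the backward direction sees [5, 4] rather than [0, 5, 4], so its final state summarizes only real tokens. Naively reversing the full time axis would instead put the padding first.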