recurrent.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:neuralmonkey 作者: ufal 项目源码 文件源码
def rnn_layer(rnn_input: tf.Tensor, lengths: tf.Tensor,
              rnn_spec: RNNSpec) -> Tuple[tf.Tensor, tf.Tensor]:
    """Build a recurrent layer over the input according to a spec.

    Arguments:
        rnn_input: The input sequence to the RNN.
        lengths: Lengths of input sequences.
        rnn_spec: A valid RNNSpec tuple specifying the network architecture.
            Its `direction` selects the "bidirectional" or "backward"
            variant (any other value runs a plain forward RNN), and its
            `cell_type` decides how the final state is extracted.

    Returns:
        A pair `(outputs, final_state)`. In the bidirectional case the
        forward and backward outputs are concatenated along axis 2 and the
        final states along axis 1. For "LSTM" cells only the `h` attribute
        of each final state is returned.
    """
    uses_lstm = rnn_spec.cell_type == "LSTM"

    if rnn_spec.direction != "bidirectional":
        is_backward = rnn_spec.direction == "backward"

        # A backward RNN is a forward RNN run over the reversed sequences.
        if is_backward:
            rnn_input = tf.reverse_sequence(rnn_input, lengths, seq_axis=1)

        rnn_cell = _make_rnn_cell(rnn_spec)
        outputs, last_state = tf.nn.dynamic_rnn(
            rnn_cell, rnn_input, sequence_length=lengths, dtype=tf.float32)

        # Flip the outputs back so they follow the original sequence order.
        if is_backward:
            outputs = tf.reverse_sequence(outputs, lengths, seq_axis=1)

        return outputs, (last_state.h if uses_lstm else last_state)

    # Bidirectional: two independent cells, one per direction.
    forward_cell = _make_rnn_cell(rnn_spec)
    backward_cell = _make_rnn_cell(rnn_spec)

    both_outputs, both_states = tf.nn.bidirectional_dynamic_rnn(
        forward_cell, backward_cell, rnn_input, sequence_length=lengths,
        dtype=tf.float32)

    outputs = tf.concat(both_outputs, 2)

    # LSTM final states carry several parts; only `h` is kept.
    if uses_lstm:
        state_parts = [direction_state.h for direction_state in both_states]
    else:
        state_parts = list(both_states)

    return outputs, tf.concat(state_parts, 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号