def rnn_layer(rnn_input: tf.Tensor, lengths: tf.Tensor,
              rnn_spec: RNNSpec) -> Tuple[tf.Tensor, tf.Tensor]:
    """Construct an RNN layer given its inputs and specs.

    Arguments:
        rnn_input: The input sequence to the RNN.
        lengths: Lengths of input sequences.
        rnn_spec: A valid RNNSpec tuple specifying the network architecture.

    Returns:
        A pair of the RNN outputs over time and the final state.
    """
    if rnn_spec.direction == "bidirectional":
        # Independent cells for the forward and the backward pass.
        outputs_pair, states_pair = tf.nn.bidirectional_dynamic_rnn(
            _make_rnn_cell(rnn_spec), _make_rnn_cell(rnn_spec),
            rnn_input, sequence_length=lengths, dtype=tf.float32)

        # LSTM states come as (c, h) tuples; keep only the hidden part.
        if rnn_spec.cell_type == "LSTM":
            states_pair = [state.h for state in states_pair]

        return (tf.concat(outputs_pair, 2),
                tf.concat(list(states_pair), 1))

    # Unidirectional case. A backward RNN is a forward RNN applied to
    # the sequence reversed up to each example's true length.
    reverse = rnn_spec.direction == "backward"
    if reverse:
        rnn_input = tf.reverse_sequence(rnn_input, lengths, seq_axis=1)

    outputs, final_state = tf.nn.dynamic_rnn(
        _make_rnn_cell(rnn_spec), rnn_input,
        sequence_length=lengths, dtype=tf.float32)

    if reverse:
        # Undo the input reversal so outputs align with the input order.
        outputs = tf.reverse_sequence(outputs, lengths, seq_axis=1)

    if rnn_spec.cell_type == "LSTM":
        final_state = final_state.h

    return outputs, final_state