def bidirectional_rnn(forward_cell, backward_cell, inputs, seq_lens_mask, concatenate=True):
    """Run one RNN forwards and one backwards over ``inputs``.

    The backward pass is fed a per-example time-reversed copy of ``inputs``
    and its outputs are reversed back afterwards, so forward and backward
    outputs are aligned on the time axis.

    Args:
        forward_cell: cell handed to ``rnn`` for the forward pass.
        backward_cell: cell handed to ``rnn`` for the backward pass.
        inputs: input tensor, batch x time x embedding_dim.
        seq_lens_mask: mask whose sum along axis 1 gives each example's
            sequence length (presumably batch x time with 1 marking valid
            steps -- confirm against the caller). Passed unchanged to both
            ``rnn`` calls.
        concatenate: when truthy, join forward and backward results along
            the feature axis; otherwise return them separately.

    Returns:
        If ``concatenate`` is truthy, ``(outputs, last_state)`` with
        outputs batch x time x hidden_size*2 and last_state
        batch x hidden_size*2.
        Otherwise ``(forward_outputs, forward_last_state, backward_outputs,
        backward_last_state)`` with outputs batch x time x hidden_size and
        states batch x hidden_size.
    """
    # Per-example sequence lengths, recovered from the mask.
    mask_totals = tf.reduce_sum(seq_lens_mask, 1)
    lengths = tf.cast(mask_totals, tf.int32)

    # Time-reverse each example (batch x time x embedding_dim); passing the
    # lengths is what takes care of variable sequence lengths.
    flipped_inputs = tf.reverse_sequence(inputs, lengths, seq_dim=1, batch_dim=0)

    # Forward pass on the original inputs, backward pass on the flipped ones.
    fw_outputs, fw_state = rnn(forward_cell, inputs, seq_lens_mask)
    bw_outputs_flipped, bw_state = rnn(backward_cell, flipped_inputs, seq_lens_mask)

    # Undo the reversal so backward outputs line up with forward outputs.
    bw_outputs = tf.reverse_sequence(bw_outputs_flipped, lengths, seq_dim=1, batch_dim=0)

    if not concatenate:
        # Dimensions: outputs (batch x time x hidden_size); last_state (batch x hidden_size)
        return (fw_outputs, fw_state, bw_outputs, bw_state)

    # NOTE(review): tf.concat(axis, values) looks like the pre-1.0 TensorFlow
    # argument order -- confirm the targeted TF version before upgrading.
    # last_state inputs are batch x hidden_size, joined on axis 1.
    merged_state = tf.concat(1, [fw_state, bw_state])
    # outputs inputs are batch x time x hidden_size, joined on axis 2.
    merged_outputs = tf.concat(2, [fw_outputs, bw_outputs])
    # Dimensions: outputs (batch x time x hidden_size*2); last_state (batch x hidden_size*2)
    return (merged_outputs, merged_state)
评论列表
文章目录