import tensorflow as tf
from tensorflow.contrib.rnn import MultiRNNCell

# Params and apply_dropout are assumed to be defined elsewhere in this project
# (params.py / layers.py), as in the surrounding code.

def bidirectional_GRU(inputs, inputs_len, cell = None, cell_fn = tf.contrib.rnn.GRUCell, units = Params.attn_size, layers = 1, scope = "Bidirectional_GRU", output = 0, is_training = True, reuse = None):
    '''
    Bidirectional recurrent neural network with GRU cells.

    Args:
        inputs:     rnn input of shape (batch_size, timestep, dim)
        inputs_len: rnn input lengths of shape (batch_size, )
        cell:       optional (cell_fw, cell_bw) pair of RNNCell instances;
                    if None, GRU cells (with dropout) are created here.
        output:     if 0, returns the rnn output for every timestep;
                    if 1, returns the concatenated final states of the
                    forward and backward rnn.
    '''
    with tf.variable_scope(scope, reuse = reuse):
        # Shape info is needed both for flattening 4-D inputs and for the
        # output == 1 case below, so compute it before branching on `cell`.
        shapes = inputs.get_shape().as_list()
        if cell is not None:
            (cell_fw, cell_bw) = cell
        else:
            # Flatten 4-D inputs (e.g. char-level embeddings) to 3-D so they
            # can be fed to dynamic_rnn.
            if len(shapes) > 3:
                inputs = tf.reshape(inputs, (shapes[0] * shapes[1], shapes[2], -1))
                inputs_len = tf.reshape(inputs_len, (shapes[0] * shapes[1],))
            # If no cells are provided, use the standard GRU cell implementation,
            # wrapped with dropout during training.
            if layers > 1:
                cell_fw = MultiRNNCell([apply_dropout(cell_fn(units), size = inputs.shape[-1] if i == 0 else units, is_training = is_training) for i in range(layers)])
                cell_bw = MultiRNNCell([apply_dropout(cell_fn(units), size = inputs.shape[-1] if i == 0 else units, is_training = is_training) for i in range(layers)])
            else:
                cell_fw, cell_bw = [apply_dropout(cell_fn(units), size = inputs.shape[-1], is_training = is_training) for _ in range(2)]
        outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs,
                                                          sequence_length = inputs_len,
                                                          dtype = tf.float32)
        if output == 0:
            # Per-timestep outputs of both directions, concatenated on the
            # feature axis: (batch_size, timestep, 2 * units)
            return tf.concat(outputs, 2)
        elif output == 1:
            # Final forward/backward states, reshaped to
            # (batch_size, shapes[1], 2 * units); used with 4-D (char-level) inputs.
            return tf.reshape(tf.concat(states, 1), (Params.batch_size, shapes[1], 2 * units))
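
A minimal usage sketch for the output = 0 case (the tensor shapes here are hypothetical, and Params is assumed to come from this project's params.py):

# Hypothetical 3-D word-level input: (batch, 100 timesteps, 300-dim embeddings).
import tensorflow as tf

passage = tf.placeholder(tf.float32, [Params.batch_size, 100, 300])  # (batch, timestep, dim)
passage_len = tf.placeholder(tf.int32, [Params.batch_size])          # (batch, )

# output = 0 returns per-timestep encodings of shape
# (Params.batch_size, 100, 2 * Params.attn_size).
passage_encoding = bidirectional_GRU(passage,
                                     passage_len,
                                     units = Params.attn_size,
                                     layers = 1,
                                     scope = "passage_encoding",
                                     output = 0,
                                     is_training = True)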