def birnn(cell, inputs, sequence_length, initial_state_fw=None, initial_state_bw=None, ff_keep_prob=1., recur_keep_prob=1., dtype=tf.float32, scope=None):
    """Run a bidirectional RNN built from two calls to `rnn`.

    The forward pass runs `rnn` over `inputs` as given. The backward pass:
      * reverses each sequence along time (only its first
        `sequence_length[b]` steps, via `tf.reverse_sequence`);
      * runs `rnn` over the reversed sequences;
      * reverses the result back, so that step t of both directions refers
        to the same input step.
    The two outputs are then concatenated along axis 2.

    Args:
      cell: Recurrent cell passed unchanged to both `rnn` calls. The same
        object is used for both directions; they are kept apart by running
        under different variable scopes.
      inputs: Batch-major tensor. Axis 0 is the batch axis and axis 1 is the
        time axis, as handed to `tf.reverse_sequence`. Presumably
        (batch, time, features) — TODO confirm against `rnn`.
      sequence_length: Per-example valid lengths. Used for the time reversal
        and forwarded to `rnn`.
      initial_state_fw: Initial state for the forward `rnn` call (or None).
      initial_state_bw: Initial state for the backward `rnn` call (or None).
      ff_keep_prob: Forwarded to `rnn`. Presumably the feed-forward dropout
        keep probability — confirm in `rnn`.
      recur_keep_prob: Forwarded to `rnn`. Presumably the recurrent dropout
        keep probability — confirm in `rnn`.
      dtype: Forwarded to `rnn`.
      scope: Optional variable scope (name or VariableScope).
        * When None/empty, the two directions use top-level 'BiRNN_FW' and
          'BiRNN_BW' scopes (same variable names as before).
        * When given, they use 'BiRNN_FW' and 'BiRNN_BW' sub-scopes nested
          under `scope`.

    Returns:
      A pair (outputs, states):
        * outputs: the forward and backward outputs concatenated along
          axis 2;
        * states: `tf.tuple` of the forward and backward final states
          returned by `rnn`.
    """
    if not scope:
        # Default path: identical scope names to the original implementation
        # ('BiRNN_FW' / 'BiRNN_BW' at the current level).
        return _birnn_directions(cell, inputs, sequence_length,
                                 initial_state_fw, initial_state_bw,
                                 ff_keep_prob, recur_keep_prob, dtype)
    # Nest both directions under the caller's scope.
    # Previously `scope` itself was opened for BOTH directions. That put the
    # forward and backward variables in the same scope, so the second `rnn`
    # call presumably hit a "variable already exists" error or, under reuse,
    # silently shared the forward weights.
    with tf.variable_scope(scope):
        return _birnn_directions(cell, inputs, sequence_length,
                                 initial_state_fw, initial_state_bw,
                                 ff_keep_prob, recur_keep_prob, dtype)


def _birnn_directions(cell, inputs, sequence_length, initial_state_fw,
                      initial_state_bw, ff_keep_prob, recur_keep_prob, dtype):
    """Run both directions of `birnn` in 'BiRNN_FW'/'BiRNN_BW' under the current variable scope."""
    # Forward direction.
    with tf.variable_scope('BiRNN_FW') as fw_scope:
        output_fw, output_state_fw = rnn(cell, inputs, sequence_length, initial_state_fw, ff_keep_prob, recur_keep_prob, dtype, scope=fw_scope)

    # Backward direction.
    # Reverse the valid prefix of each sequence (seq axis 1, batch axis 0).
    # Steps beyond sequence_length are left where they are.
    rev_inputs = tf.reverse_sequence(inputs, sequence_length, 1, 0)
    with tf.variable_scope('BiRNN_BW') as bw_scope:
        output_bw, output_state_bw = rnn(cell, rev_inputs, sequence_length, initial_state_bw, ff_keep_prob, recur_keep_prob, dtype, scope=bw_scope)
    # Undo the reversal so step t of output_bw lines up with step t of output_fw.
    output_bw = tf.reverse_sequence(output_bw, sequence_length, 1, 0)

    # Concatenate the two directions along axis 2 (presumably the feature axis).
    outputs = tf.concat([output_fw, output_bw], 2)
    # tf.tuple makes each returned state available only after both are computed.
    # NOTE(review): tf.tuple expects a flat list of tensors; confirm `rnn`
    # returns plain tensors (not nested state tuples) as its final state.
    return outputs, tf.tuple([output_state_fw, output_state_bw])
#===============================================================
评论列表
文章目录