def fused_birnn(fused_rnn, inputs, sequence_length, initial_state=(None, None), dtype=None, scope=None,
                time_major=False, backward_device=None):
    """Run *fused_rnn* over *inputs* in both directions and return both results.

    The forward pass runs under variable scope ``FW`` and the backward pass under
    ``BW`` (via ``fused_rnn_backward``, defined elsewhere in this file), both inside
    an enclosing ``scope`` (default ``"BiRNN"``).

    Args:
        fused_rnn: A fused RNN cell/layer callable accepting
            ``(inputs, sequence_length=..., initial_state=..., dtype=..., scope=...)``.
        inputs: Input tensor. If ``time_major`` is False it is transposed on axes
            ``[1, 0, 2]`` before the RNN runs — assumes a rank-3
            (batch, time, features) layout; TODO confirm against callers.
        sequence_length: Per-example lengths; cast to ``tf.int32``.
        initial_state: Pair ``(forward_state, backward_state)``; either may be None.
        dtype: Optional dtype forwarded to both passes.
        scope: Optional enclosing variable scope name (default "BiRNN").
        time_major: If True, inputs/outputs are already time-major and are not
            transposed.
        backward_device: Optional device string; when given, the backward pass is
            pinned to it with ``tf.device``.

    Returns:
        ``((outputs_fw, outputs_bw), (state_fw, state_bw))``, with outputs
        transposed back to the input's original layout when ``time_major`` is False.
    """
    with tf.variable_scope(scope or "BiRNN"):
        sequence_length = tf.cast(sequence_length, tf.int32)
        if not time_major:
            inputs = tf.transpose(inputs, [1, 0, 2])
        outputs_fw, state_fw = fused_rnn(inputs, sequence_length=sequence_length,
                                         initial_state=initial_state[0], dtype=dtype, scope="FW")

        def _run_backward():
            # Reverse-direction pass; identical arguments either way, only the
            # device placement differs between the two call sites below.
            return fused_rnn_backward(fused_rnn, inputs, sequence_length,
                                      initial_state[1], dtype, scope="BW")

        if backward_device is None:
            outputs_bw, state_bw = _run_backward()
        else:
            # Optionally pin the backward pass to a separate device.
            with tf.device(backward_device):
                outputs_bw, state_bw = _run_backward()
    if not time_major:
        # Restore batch-major layout to match the caller's input convention.
        outputs_fw = tf.transpose(outputs_fw, [1, 0, 2])
        outputs_bw = tf.transpose(outputs_bw, [1, 0, 2])
    return (outputs_fw, outputs_bw), (state_fw, state_bw)