def fused_rnn_backward(fused_rnn, inputs, sequence_length, initial_state=None, dtype=None, scope=None,
                       time_major=True):
    """Run ``fused_rnn`` over the time-reversed input (the backward pass of a bi-RNN).

    Each sequence is reversed along the time axis and fed through ``fused_rnn``.
    The resulting outputs are then reversed back, so ``outputs[t]`` lines up with
    ``inputs[t]`` in the original time order.

    Args:
        fused_rnn: Callable invoked as ``fused_rnn(x, sequence_length=..., initial_state=...,
            dtype=..., scope=...)`` on time-major input. It must return ``(outputs, state)``.
            It is presumably a ``tf.contrib.rnn.FusedRNNCell`` (confirm against callers).
        inputs: Rank-3 tensor, ``[time, batch, d]`` if ``time_major`` else ``[batch, time, d]``.
        sequence_length: Per-example lengths with shape ``[batch]``. Only the first
            ``sequence_length[b]`` steps of example ``b`` are reversed; the padding
            after them is left in place. If ``None``, every sequence is treated as
            spanning the whole time axis.
        initial_state: Forwarded unchanged to ``fused_rnn``.
        dtype: Forwarded unchanged to ``fused_rnn``.
        scope: Forwarded unchanged to ``fused_rnn``.
        time_major: Whether ``inputs`` (and the returned ``outputs``) are time-major.

    Returns:
        A tuple ``(outputs, last_state)``:

        - ``outputs`` has the same layout as ``inputs`` and is in original time order.
        - ``last_state`` is the state ``fused_rnn`` returns after consuming the
          reversed sequence, i.e. after reading the original first time step.
    """

    def reverse_time(t):
        # Expects time on axis 0 and batch on axis 1 (time-major layout).
        if sequence_length is None:
            # No lengths: every sequence fills the time axis, so flip it entirely.
            return tf.reverse(t, [0])
        # Positional args are seq_axis=0, batch_axis=1.
        # They are kept positional for compatibility across TF 1.x signatures.
        return tf.reverse_sequence(t, sequence_length, 0, 1)

    if not time_major:
        # fused_rnn is called on time-major input, so swap batch and time axes.
        inputs = tf.transpose(inputs, [1, 0, 2])
    rev_inputs = reverse_time(inputs)
    rev_outputs, last_state = fused_rnn(rev_inputs, sequence_length=sequence_length,
                                        initial_state=initial_state, dtype=dtype, scope=scope)
    # Undo the reversal so outputs align with the original time steps.
    outputs = reverse_time(rev_outputs)
    if not time_major:
        outputs = tf.transpose(outputs, [1, 0, 2])
    return outputs, last_state
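# Comment list  (评论列表 — leftover web-page navigation text, not code)
# Table of contents  (文章目录 — leftover web-page navigation text, not code)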