def _reverse_seq(input_seq, lengths):
"""Reverse a list of Tensors up to specified lengths.
Args:
input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
lengths: A tensor of dimension batch_size, containing lengths for each
sequence in the batch. If "None" is specified, simply reverses
the list.
Returns:
time-reversed sequence
"""
if lengths is None:
return list(reversed(input_seq))
input_shape = tensor_shape.matrix(None, None)
for input_ in input_seq:
input_shape.merge_with(input_.get_shape())
input_.set_shape(input_shape)
# Join into (time, batch_size, depth)
s_joined = tf.stack(input_seq)
if lengths is not None:
lengths = tf.to_int64(lengths)
# Reverse along dimension 0
s_reversed = tf.reverse_sequence(s_joined, lengths, 0, 1)
# Split again into list
result = tf.unstack(s_reversed)
for r in result:
r.set_shape(input_shape)
return result
# (Web-page navigation residue, not part of the code: "Comment list" / "Table of contents")