def seq2nonseq(tensorlist, seq_length, name=None):
    """Flatten padded sequential data into a single non-sequential tensor.

    Args:
        tensorlist: the sequential data, a list holding one N x F tensor per
            time step, where N is the batch size and F is the input
            dimension.
        seq_length: a vector holding the length of each sequence in the
            batch.
        name: [optional] the name of the operation.

    Returns:
        A T x F tensor, where T is the sum of all sequence lengths. The
        valid time steps of sequence 0 come first, followed by those of
        sequence 1, and so on.
    """
    with tf.name_scope(name or 'seq2nonseq'):
        # Stack the per-time-step tensors into one tensor, then split it
        # along axis 1 (the batch axis) so there is one tensor per sequence.
        # NOTE(review): unpacking presumably needs the batch size to be
        # known when the graph is built, since the result is iterated as a
        # Python list below.
        stacked = tf.pack(tensorlist)
        per_sequence = tf.unpack(stacked, axis=1)

        # Keep only the first seq_length[index] time steps of each sequence,
        # which drops the padding.
        trimmed = []
        for index, sequence in enumerate(per_sequence):
            valid_steps = tf.range(seq_length[index])
            trimmed.append(tf.gather(sequence, valid_steps))

        # Join the trimmed sequences end to end along the time axis.
        return tf.concat(0, trimmed)
评论列表
文章目录