def nonseq2seq(tensor, seq_length, length, name=None):
    '''
    Convert non-sequential data to sequential data.

    The rows of `tensor` are the concatenation of all sequences in the batch:
    the first seq_length[0] rows belong to sequence 0, the next
    seq_length[1] rows to sequence 1, and so on. Each sequence is cut out,
    zero-padded at the end to `length` rows, and the padded batch is then
    split into one tensor per time step.

    Args:
        tensor: non-sequential data, a T x F tensor where T is the sum of
            all sequence lengths and F is the input dimension. F must be
            statically known.
        seq_length: a vector containing the sequence lengths. Its size N
            (the batch size) must be statically known. It is cast to int32
            internally, so int64 length vectors are accepted as well.
        length: the constant length of the output sequences, as a Python
            int. Every entry of seq_length must be at most `length`;
            otherwise tf.pad receives a negative padding and fails at run
            time.
        name: [optional] the name of the operation

    Returns:
        sequential data, which is a list of `length` tensors, one per time
        step, each of shape N x F where N is the batch size and F is the
        input dimension
    '''
    with tf.name_scope(name or 'nonseq2seq'):
        # tf.constant([0]) below is int32 and tf.concat requires matching
        # dtypes, so bring the lengths to int32 (a no-op for int32 input).
        seq_length = tf.cast(seq_length, tf.int32)

        # Cumulative sequence lengths: sequence s occupies rows
        # [cum_seq_length[s], cum_seq_length[s + 1]) of tensor.
        cum_seq_length = tf.concat(
            0, [tf.constant([0]), tf.cumsum(seq_length)])

        # Both sizes must be static: the Python loop below is unrolled over
        # the batch, and set_shape needs a concrete input dimension.
        batch_size = int(seq_length.get_shape()[0])
        input_dim = int(tensor.get_shape()[1])

        sequences = []
        for s in range(batch_size):
            # Cut out the rows that belong to sequence s.
            indices = tf.range(cum_seq_length[s], cum_seq_length[s + 1])
            sequence = tf.gather(tensor, indices)

            # Zero-pad at the end (time axis only) up to the constant length.
            sequence = tf.pad(sequence, [[0, length - seq_length[s]], [0, 0]])

            # The padding amount is a dynamic tensor, so TF cannot infer the
            # padded shape; state it explicitly.
            sequence.set_shape([length, input_dim])
            sequences.append(sequence)

        # Stack to N x length x F, then split along the time axis into a
        # list of `length` tensors of shape N x F.
        # NOTE: tf.concat(axis, values), tf.pack and tf.unpack are the
        # pre-1.0 TensorFlow API; TF >= 1.0 uses tf.concat(values, axis),
        # tf.stack and tf.unstack.
        tensorlist = tf.unpack(tf.pack(sequences), axis=1)

        return tensorlist
# (Leftover web-page navigation labels removed from the code:
#  "Comment list", "Table of contents".)