seq_convertors.py source code

python

Project: tfkaldi  Author: vrenkens
import tensorflow as tf


def nonseq2seq(tensor, seq_length, length, name=None):
    '''
    Convert non sequential data to sequential data

    Args:
        tensor: non sequential data, which is a TxF tensor where T is the sum of
            all sequence lengths
        seq_length: a vector containing the sequence lengths
        length: the constant length of the output sequences
        name: [optional] the name of the operation

    Returns:
        sequential data, which is a list containing an N x F
        tensor for each time step, where N is the batch size and F is the
        input dimension
    '''

    with tf.name_scope(name or 'nonseq2seq'):
        #get the cumulated sequence lengths to specify the positions in tensor
        #(tf.concat takes the values first and the axis second in
        #TensorFlow >= 1.0)
        cum_seq_length = tf.concat([tf.constant([0]), tf.cumsum(seq_length)], 0)

        #get the indices in the tensor for each sequence
        #(the number of sequences must be known statically)
        indices = [tf.range(cum_seq_length[l], cum_seq_length[l+1])
                   for l in range(int(seq_length.get_shape()[0]))]

        #create the non-padded sequences
        sequences = [tf.gather(tensor, i) for i in indices]

        #pad the sequences with zeros
        sequences = [tf.pad(sequences[s], [[0, length-seq_length[s]], [0, 0]])
                     for s in range(len(sequences))]

        #specify that the sequences have been padded to the constant length
        for seq in sequences:
            seq.set_shape([length, int(tensor.get_shape()[1])])

        #convert the list for each sequence to a list for each time step
        #(tf.stack and tf.unstack are the TensorFlow >= 1.0 names of
        #tf.pack and tf.unpack)
        tensorlist = tf.unstack(tf.stack(sequences), axis=1)

    return tensorlist
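
To make the reshaping easier to follow, here is a minimal pure-Python sketch of the same slice, pad and transpose steps on plain lists. It is not part of tfkaldi: the name `nonseq2seq_reference` and the sample data are hypothetical, chosen only for illustration, and the sketch assumes the sequence lengths sum to the number of rows.

```python
def nonseq2seq_reference(rows, seq_lengths, length):
    """Pure-Python sketch of the reshaping done by nonseq2seq.

    rows: list of T feature vectors (each a list of F numbers), with all
        sequences concatenated along the time axis
    seq_lengths: list of N sequence lengths that sum to T
    length: padded length of every output sequence

    Returns a list of `length` time steps, each a list of N feature vectors.
    """
    if sum(seq_lengths) != len(rows):
        raise ValueError('seq_lengths must sum to the number of rows')
    if max(seq_lengths) > length:
        raise ValueError('length must be at least the longest sequence')

    feature_dim = len(rows[0])
    padded = []
    start = 0
    for seq_len in seq_lengths:
        # slice out one sequence, then pad it with zero vectors
        seq = [list(row) for row in rows[start:start + seq_len]]
        seq += [[0] * feature_dim for _ in range(length - seq_len)]
        padded.append(seq)
        start += seq_len

    # transpose from [sequence][time] to [time][sequence]
    return [[seq[t] for seq in padded] for t in range(length)]


# hypothetical example: two sequences of lengths 2 and 3, padded to length 3
rows = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
steps = nonseq2seq_reference(rows, seq_lengths=[2, 3], length=3)
print(steps[0])  # first time step of both sequences
print(steps[2])  # last time step; the shorter sequence is zero-padded here
```

Each element of `steps` corresponds to one of the N x F tensors in the list returned by `nonseq2seq`.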