seq_convertors.py source code

python

Project: tfkaldi  Author: vrenkens
import tensorflow as tf


def seq2nonseq(tensorlist, seq_length, name=None):
    '''
    Convert sequential data to non-sequential data

    Args:
        tensorlist: the sequential data, which is a list containing an N x F
            tensor for each time step, where N is the batch size and F is the
            input dimension
        seq_length: a vector of length N containing the sequence lengths
        name: [optional] the name of the operation

    Returns:
        non-sequential data, which is a T x F tensor where T is the sum of all
        sequence lengths
    '''

    with tf.name_scope(name or 'seq2nonseq'):
        # convert the list for each time step to a list for each sequence:
        # stack to [max_length x N x F], then unstack along the batch axis
        # (needs a static N; tf.pack/tf.unpack were renamed
        # tf.stack/tf.unstack in TF 1.0)
        sequences = tf.unstack(tf.stack(tensorlist), axis=1)

        # remove the padding by keeping only the first seq_length[s]
        # time steps of sequence s
        sequences = [tf.gather(sequences[s], tf.range(seq_length[s]))
                     for s in range(len(sequences))]

        # concatenate the sequences along the time axis
        # (TF >= 1.0 signature: tf.concat(values, axis))
        tensor = tf.concat(sequences, 0)

    return tensor
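To make the shapes concrete, here is a minimal sketch of the same conversion in NumPy, used as a stand-in for TensorFlow so it can run without a graph or session. `seq2nonseq_np` and `seq2nonseq_masked` are hypothetical names, not part of tfkaldi. The first function mirrors the steps of seq2nonseq: stack the per-time-step N x F arrays, split per sequence, trim each sequence to its length, and concatenate, giving T = sum(seq_length) rows ordered sequence by sequence (all valid steps of sequence 0, then sequence 1, and so on). The second is an assumed mask-based alternative that yields the same rows without a Python loop over the batch.

```python
import numpy as np


def seq2nonseq_np(tensorlist, seq_length):
    """Loop-based NumPy mirror of seq2nonseq (hypothetical helper).

    tensorlist: list of max_length arrays, each of shape [N, F]
    seq_length: sequence lengths, shape [N]
    returns: array of shape [sum(seq_length), F]
    """
    stacked = np.stack(tensorlist)  # [max_length, N, F]
    # one [max_length, F] array per sequence (like tf.unstack(..., axis=1))
    sequences = [stacked[:, s] for s in range(stacked.shape[1])]
    # drop the padded time steps at the end of each sequence
    trimmed = [seq[:length] for seq, length in zip(sequences, seq_length)]
    return np.concatenate(trimmed, axis=0)


def seq2nonseq_masked(tensorlist, seq_length):
    """Mask-based alternative (an assumption, not the tfkaldi code)."""
    batch_major = np.stack(tensorlist, axis=1)  # [N, max_length, F]
    max_length = batch_major.shape[1]
    # mask[s, t] is True for the valid (non-padded) time steps of sequence s
    mask = np.arange(max_length)[None, :] < np.asarray(seq_length)[:, None]
    # boolean indexing keeps rows in row-major order: sequence first, then time
    return batch_major[mask]


if __name__ == "__main__":
    # batch of N=2 sequences, max_length=3 time steps, F=2 features
    tensorlist = [np.array([[1, 1], [2, 2]]),   # time step 0
                  np.array([[3, 3], [4, 4]]),   # time step 1
                  np.array([[5, 5], [0, 0]])]   # time step 2 (sequence 1 is padding)
    seq_length = [3, 2]
    out = seq2nonseq_np(tensorlist, seq_length)
    print(out.shape)     # (5, 2): T = 3 + 2
    print(out.tolist())  # [[1, 1], [3, 3], [5, 5], [2, 2], [4, 4]]
```

In TensorFlow the mask-based variant would presumably be written with `tf.sequence_mask` and `tf.boolean_mask` on a batch-major tensor, which also removes the need for a static batch size; that version is not tested here.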