tfrecords_io.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pit-speech-separation 作者: snsun 项目源码 文件源码
def get_padded_batch(file_list, batch_size, input_size, output_size,
                     num_enqueuing_threads=4, num_epochs=1, shuffle=True):
    """Reads batches of SequenceExamples from TFRecords and pads them.
    Can deal with variable length SequenceExamples by padding each batch to the
    length of the longest sequence with zeros.
    Args:
        file_list: A list of paths to TFRecord files containing SequenceExamples.
        batch_size: The number of SequenceExamples to include in each batch.
        input_size: The size of each input vector. The returned batch of inputs
            will have a shape [batch_size, num_steps, input_size].
        num_enqueuing_threads: The number of threads to use for enqueuing
            SequenceExamples.
    Returns:
        inputs: A tensor of shape [batch_size, num_steps, input_size] of floats32s.
        labels: A tensor of shape [batch_size, num_steps] of float32s.
        lengths: A tensor of shape [batch_size] of int32s. The lengths of each
            SequenceExample before padding.
    """
    file_queue = tf.train.string_input_producer(
        file_list, num_epochs=num_epochs, shuffle=shuffle)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)

    sequence_features = {
        'inputs': tf.FixedLenSequenceFeature(shape=[input_size],
                                         dtype=tf.float32),
        'labels': tf.FixedLenSequenceFeature(shape=[output_size],
                                         dtype=tf.float32),
         'genders': tf.FixedLenSequenceFeature(shape=[2], dtype=tf.float32)}

    _, sequence = tf.parse_single_sequence_example(
        serialized_example, sequence_features=sequence_features)

    length = tf.shape(sequence['inputs'])[0]

    capacity = 1000 + (num_enqueuing_threads + 1) * batch_size
    queue = tf.PaddingFIFOQueue(
        capacity=capacity,
        dtypes=[tf.float32, tf.float32, tf.float32, tf.int32],
        shapes=[(None, input_size), (None, output_size),(1,2), ()])

    enqueue_ops = [queue.enqueue([sequence['inputs'],
                                  sequence['labels'],
                                  sequence['genders'],
                                  length])] * num_enqueuing_threads

    tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops))
    return queue.dequeue_many(batch_size)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号