tf_queuing.py source code

python

Project: adventures-in-ml-code  Author: adventuresinML
import tensorflow as tf  # written for TF 1.x; in TF 2.x this queue API exists only as tf.compat.v1.train.* in graph mode


def cifar_shuffle_batch():
    batch_size = 128
    num_threads = 16
    # data_path, read_data() and cifar_run() are assumed to be defined elsewhere in tf_queuing.py
    # create a list of the five CIFAR-10 training files (data_batch_1.bin ... data_batch_5.bin)
    filename_list = [data_path + 'data_batch_{}.bin'.format(i + 1) for i in range(5)]
    # create a filename queue
    # file_q = cifar_filename_queue(filename_list)
    file_q = tf.train.string_input_producer(filename_list)
    # read the data: read_data() uses a FixedLengthRecordReader, which dequeues file names
    # from file_q and reads fixed-size records. It returns one processed image and its label,
    # with shapes ready for a convolutional neural network
    image, label = read_data(file_q)
    # minimum number of examples that must remain in the queue after a dequeue: a dequeue
    # blocks until enough examples are queued that at least this many would remain afterwards.
    # A higher number gives better mixing but a longer initial load time
    min_after_dequeue = 10000
    # queue capacity, following TensorFlow's recommendation of
    # min_after_dequeue + (num_threads + a small safety margin) * batch_size
    capacity = min_after_dequeue + (num_threads + 1) * batch_size
    # image_batch, label_batch = cifar_shuffle_queue_batch(image, label, batch_size, num_threads)
    image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                                      capacity=capacity,
                                                      min_after_dequeue=min_after_dequeue,
                                                      num_threads=num_threads)
    # now run the training
    cifar_run(image_batch, label_batch)
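The snippet calls tf.train.string_input_producer with its defaults, which in TF 1.x shuffle the file names within each epoch and, with num_epochs=None, cycle through them indefinitely. The generator below is a pure-Python model of that ordering only, assuming those defaults. The name filename_queue and its parameters are hypothetical, and the model ignores the real queue's threading and capacity.

```python
import itertools
import random
from typing import Iterator, List, Optional


def filename_queue(filenames: List[str], num_epochs: Optional[int] = None,
                   shuffle: bool = True, seed: int = 0) -> Iterator[str]:
    """Hypothetical model of a filename queue's output order.

    Every name appears once per epoch, reshuffled each epoch when shuffle=True;
    the sequence never ends when num_epochs is None.
    """
    rng = random.Random(seed)
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for _ in epochs:
        order = list(filenames)
        if shuffle:
            rng.shuffle(order)  # new random file order for this epoch
        yield from order


names = ['data_batch_{}.bin'.format(i + 1) for i in range(5)]
two_epochs = list(filename_queue(names, num_epochs=2))
```

Each run of five names in two_epochs is a permutation of the five CIFAR-10 training files, so no file is read twice before every file has been read once.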
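read_data() is not shown in this snippet. Assuming it follows the standard CIFAR-10 binary layout, each record its FixedLengthRecordReader returns is 3073 bytes: one label byte followed by 3072 pixel bytes stored channel by channel (1024 red, then 1024 green, then 1024 blue, each a 32x32 row-major plane). The hypothetical helpers below decode that layout in plain Python and reorder the pixels to height x width x channel, the layout convolutional networks usually expect.

```python
# Standard CIFAR-10 binary layout (assumed to match what read_data() consumes)
LABEL_BYTES = 1
HEIGHT, WIDTH, DEPTH = 32, 32, 3
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH       # 3072
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES   # 3073


def split_records(data: bytes):
    """Hypothetical helper: cut a file payload into fixed-length records."""
    if len(data) % RECORD_BYTES:
        raise ValueError("payload size is not a multiple of the record size")
    return [data[i:i + RECORD_BYTES] for i in range(0, len(data), RECORD_BYTES)]


def decode_cifar_record(record: bytes):
    """Hypothetical helper: split one record into (label, image[row][col][channel])."""
    if len(record) != RECORD_BYTES:
        raise ValueError("expected %d bytes, got %d" % (RECORD_BYTES, len(record)))
    label = record[0]
    pixels = record[LABEL_BYTES:]
    # pixels are stored channel-major: index = channel * H * W + row * W + col
    image = [[[pixels[c * HEIGHT * WIDTH + r * WIDTH + col] for c in range(DEPTH)]
              for col in range(WIDTH)]
             for r in range(HEIGHT)]
    return label, image
```

For example, a synthetic record bytes([7]) + bytes([10]) * 1024 + bytes([20]) * 1024 + bytes([30]) * 1024 decodes to label 7 with every pixel equal to [10, 20, 30].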
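Plugging the snippet's values into its capacity rule gives the queue size requested from tf.train.shuffle_batch: 10000 + (16 + 1) * 128 = 10000 + 2176 = 12176 examples, i.e. 17 batches of headroom above min_after_dequeue. The small helper below makes the rule explicit; safety_margin is a hypothetical name for the "+ 1" term.

```python
def shuffle_queue_capacity(min_after_dequeue: int, num_threads: int,
                           batch_size: int, safety_margin: int = 1) -> int:
    """Capacity rule used in cifar_shuffle_batch():
    min_after_dequeue + (num_threads + safety_margin) * batch_size."""
    return min_after_dequeue + (num_threads + safety_margin) * batch_size


capacity = shuffle_queue_capacity(10000, 16, 128)
print(capacity)  # 12176
```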
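To see why min_after_dequeue controls mixing, here is a simplified, single-threaded pure-Python model of the shuffling queue behind tf.train.shuffle_batch. It is an illustration under stated assumptions, not TensorFlow's implementation: a batch is dequeued only when at least min_after_dequeue examples would remain, each example in a batch is drawn uniformly at random from the buffer, and once the input is exhausted the remaining examples are drained in full batches, with any partial batch dropped (mirroring the default allow_smaller_final_batch=False). The function name shuffle_batches is hypothetical.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shuffle_batches(examples: Iterable[T], batch_size: int,
                    min_after_dequeue: int, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical model of a random-shuffle queue feeding fixed-size batches."""
    rng = random.Random(seed)
    buffer: List[T] = []

    def take_batch() -> List[T]:
        batch = []
        for _ in range(batch_size):
            # pick a random slot, swap it to the end and pop it (O(1) removal)
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
        return batch

    for example in examples:
        buffer.append(example)
        # dequeue only if min_after_dequeue examples would remain afterwards
        if len(buffer) >= min_after_dequeue + batch_size:
            yield take_batch()

    # input exhausted (queue closed): drain full batches, drop the remainder
    while len(buffer) >= batch_size:
        yield take_batch()


batches = list(shuffle_batches(range(1000), batch_size=32, min_after_dequeue=200))
```

Here the 1000 inputs yield 31 full batches (992 examples) and the last 8 are dropped. No batch appears until min_after_dequeue + batch_size examples have been read, which is the "longer initial load time" side of the trade-off; with min_after_dequeue=0 and batch_size=1 the model returns the input in its original order, i.e. no mixing at all.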