data.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tensorsandbox 作者: kaizouman 项目源码 文件源码
def get_raw_input_data(test_data, data_dir):
    """Build the ops that read one raw CIFAR-10 record from the binary files.

    Each record in the binary batch files is a fixed-length byte string made
    of one label byte followed by the image bytes stored depth-major
    (D x H x W). The returned ops decode one record at a time from a queue
    cycling over the selected files.

    Args:
        test_data: bool, indicating if one should use the test or train set.
        data_dir: Path to the CIFAR-10 data directory; the batch files are
            looked up in its 'cifar-10-batches-bin' sub-directory.

    Returns:
        image: an op producing a float32 image of shape
            (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)
        label: an op producing an int32 tensor of shape [1] holding the label

    Raises:
        ValueError: if data_dir does not exist, or if one of the expected
            batch files cannot be found.
    """

    # Verify first that we have a valid data directory
    if not os.path.exists(data_dir):
        raise ValueError("Data directory %s doesn't exist" % data_dir)

    # Construct a list of input file names
    batches_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
    if test_data:
        filenames = [os.path.join(batches_dir, 'test_batch.bin')]
    else:
        # The training set is split in five files, numbered from 1 to 5.
        # `range` (not the Python-2-only `xrange`) keeps this working on
        # both Python 2 and Python 3.
        filenames = [os.path.join(batches_dir, 'data_batch_%d.bin' % ii)
                     for ii in range(1, 6)]

    # Make sure all input files actually exist
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # Create a string input producer to cycle over file names
    filenames_queue = tf.train.string_input_producer(filenames)

    # CIFAR data samples are stored as contiguous labels and images
    label_size = 1
    image_size = IMAGE_DEPTH * IMAGE_HEIGHT * IMAGE_WIDTH

    # Instantiate a fixed length file reader (one record = label + image)
    reader = tf.FixedLengthRecordReader(label_size + image_size)

    # Read from files; the record key is not needed
    _, value = reader.read(filenames_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)

    # Extract label (first label_size bytes) and cast to int32
    label = tf.cast(tf.slice(record_bytes, [0], [label_size]), tf.int32)

    # Extract image (the image_size bytes after the label), cast to float32
    image = tf.cast(tf.slice(record_bytes,
                             [label_size],
                             [image_size]),
                    tf.float32)

    # Images are stored as D x H x W vectors, but we want H x W x D
    # So we need to convert to a matrix
    image = tf.reshape(image, (IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH))
    # Transpose dimensions
    image = tf.transpose(image, (1, 2, 0))

    return (image, label)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号