mnist.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:AutoGP 作者: ebonilla 项目源码 文件源码
def import_mnist(validation_size=0):
    """
    Fetch MNIST and wrap the train, test and validation splits in DataSet objects.

    The four MNIST archives are obtained through ``base.maybe_download``
    into ``experiments/data/MNIST_data``. The first ``validation_size``
    training examples are held out as the validation split. Every image set
    is passed through ``process_mnist`` and then through
    ``standardize_data`` using the values that ``get_data_info`` returns for
    the remaining training images only.

    :param validation_size: number of leading training examples to hold out
        for validation (0 yields an empty validation split)
    :return: tuple ``(data, test, val)`` of DataSet objects for the
        training, test and validation splits, in that order
    :raises ValueError: if ``validation_size`` is negative or exceeds the
        number of training examples
    """
    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    ONE_HOT = True
    TRAIN_DIR = 'experiments/data/MNIST_data'

    def _fetch(filename, extract, **kwargs):
        # Download one archive (if needed) and decode it with ``extract``.
        local_file = base.maybe_download(filename, TRAIN_DIR,
                                         SOURCE_URL + filename)
        # The archives are gzip-compressed binary data: opening them in text
        # mode breaks on Python 3 (decoding) and on Windows (newline
        # translation), so binary mode is required.
        with open(local_file, 'rb') as f:
            return extract(f, **kwargs)

    train_images = _fetch(TRAIN_IMAGES, extract_images)
    train_labels = _fetch(TRAIN_LABELS, extract_labels, one_hot=ONE_HOT)
    test_images = _fetch(TEST_IMAGES, extract_images)
    test_labels = _fetch(TEST_LABELS, extract_labels, one_hot=ONE_HOT)

    # Reject sizes that would silently produce a nonsensical split (a
    # negative value would slice from the end; an oversized value would
    # leave no training data to compute the standardization values from).
    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(len(train_images), validation_size))

    # Hold out the leading examples for validation.
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    # process images
    train_images = process_mnist(train_images)
    validation_images = process_mnist(validation_images)
    test_images = process_mnist(test_images)

    # standardize data, using values derived from the training split only so
    # that nothing from the validation/test splits leaks into preprocessing
    train_mean, train_std = get_data_info(train_images)
    train_images = standardize_data(train_images, train_mean, train_std)
    validation_images = standardize_data(validation_images, train_mean, train_std)
    test_images = standardize_data(test_images, train_mean, train_std)

    data = DataSet(train_images, train_labels)
    test = DataSet(test_images, test_labels)
    val = DataSet(validation_images, validation_labels)

    return data, test, val
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号