def import_mnist(validation_size=0):
    """Download MNIST and return its splits wrapped in ``DataSet`` objects.

    The four MNIST archives are fetched into ``experiments/data/MNIST_data``
    via ``base.maybe_download`` and decoded with ``extract_images`` /
    ``extract_labels`` (labels are extracted with ``one_hot=True``).

    The first ``validation_size`` training examples are held out as the
    validation split. All three image splits then go through
    ``process_mnist``. Finally, each split is standardized with the mean/std
    that ``get_data_info`` computes on the remaining training images only.

    :param validation_size: Number of leading training examples to hold out
        as the validation set. ``0`` (the default) gives an empty validation
        set.
    :return: Tuple ``(data, test, val)`` of ``DataSet`` objects for the
        training, test and validation splits, in that order.
    :raises ValueError: If ``validation_size`` is negative or larger than the
        number of training examples.
    """
    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    ONE_HOT = True
    TRAIN_DIR = 'experiments/data/MNIST_data'

    def _load(filename, extract, **kwargs):
        # Download one archive into TRAIN_DIR (maybe_download may reuse a
        # cached copy) and decode it with the given extractor.
        local_file = base.maybe_download(filename, TRAIN_DIR,
                                         SOURCE_URL + filename)
        # The archives are gzip-compressed (.gz). Open them in binary mode so
        # the extractor's gzip reader receives bytes; text mode breaks gzip
        # decoding on Python 3.
        with open(local_file, 'rb') as f:
            return extract(f, **kwargs)

    # Same download order as before: train images/labels, then test.
    train_images = _load(TRAIN_IMAGES, extract_images)
    train_labels = _load(TRAIN_LABELS, extract_labels, one_hot=ONE_HOT)
    test_images = _load(TEST_IMAGES, extract_images)
    test_labels = _load(TEST_LABELS, extract_labels, one_hot=ONE_HOT)

    # Reject sizes that would silently produce a wrong or empty training
    # split. A negative slice bound would hold out almost everything.
    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(len(train_images), validation_size))

    # Hold out the first `validation_size` training examples for validation.
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    # process images
    train_images = process_mnist(train_images)
    validation_images = process_mnist(validation_images)
    test_images = process_mnist(test_images)

    # Standardize every split with statistics from the training split only,
    # so no information from validation/test leaks into preprocessing.
    train_mean, train_std = get_data_info(train_images)
    train_images = standardize_data(train_images, train_mean, train_std)
    validation_images = standardize_data(validation_images, train_mean, train_std)
    test_images = standardize_data(test_images, train_mean, train_std)

    data = DataSet(train_images, train_labels)
    test = DataSet(test_images, test_labels)
    val = DataSet(validation_images, validation_labels)
    return data, test, val
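# 评论列表 ("comment list") - stray web-page text, not part of the code
# 文章目录 ("table of contents") - stray web-page text, not part of the code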