def load_mnist(path, num_training=50000, num_test=10000, cnn=True, one_hot=False):
    """Load an MNIST pickle and return leading subsets of the train/test splits.

    The file at *path* must be a gzip-compressed pickle holding a 3-tuple
    ``(training_data, validation_data, test_data)``, where each element is an
    ``(X, y)`` pair. The validation split is read but not returned.

    Args:
        path: Path to the gzip-compressed pickle. The file is unpickled, so it
            must come from a trusted source.
        num_training: Number of leading training examples to keep.
        num_test: Number of leading test examples to keep.
        cnn: If true, reshape the images to ``(-1, 1, 28, 28)``; otherwise
            keep them in the shape stored in the file.
        one_hot: If true, convert labels with ``one_hot_encode(y, 10)``.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``.

    Raises:
        ValueError: If a requested count is negative or larger than the number
            of examples stored for that split.
    """
    import gzip
    import pickle

    # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    # encoding='iso-8859-1' lets Python 3 read pickles written by Python 2.
    with gzip.open(path, 'rb') as f:
        training_data, _validation_data, test_data = pickle.load(
            f, encoding='iso-8859-1')

    train = _prepare_mnist_split(
        training_data, num_training, 'num_training', cnn, one_hot)
    test = _prepare_mnist_split(test_data, num_test, 'num_test', cnn, one_hot)
    return train, test


def _prepare_mnist_split(split, count, name, cnn, one_hot):
    """Take the first *count* examples of an ``(X, y)`` split and format them."""
    X, y = split
    available = min(len(X), len(y))
    # Validate explicitly: plain slicing would silently truncate or, for a
    # negative count, drop examples from the end instead of failing.
    if not 0 <= count <= available:
        raise ValueError('%s must be between 0 and %d, got %r'
                         % (name, available, count))
    # Subset first so reshaping / encoding only touches the kept rows.
    # Assumes one_hot_encode encodes each label independently (row-wise).
    X, y = X[:count], y[:count]
    if cnn:
        X = X.reshape((-1, 1, 28, 28))
    if one_hot:
        y = one_hot_encode(y, 10)
    return X, y
评论列表
文章目录