from sklearn import datasets, metrics, svm
from sklearn.neural_network import MLPClassifier
import matplotlib.pyplot as plt  # only needed if the plotting code below is uncommented


def mnist():
    #digits = datasets.load_digits() # subsampled version
    # fetch_mldata was removed from scikit-learn; fetch_openml is the current way to load MNIST
    mnist = datasets.fetch_openml('mnist_784', version=1, as_frame=False)
    print("Got the data.")
    X, y = mnist.data / 255., mnist.target
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    #images_and_labels = list(zip(digits.images, digits.target))
    #for index, (image, label) in enumerate(images_and_labels[:4]):
    #    plt.subplot(2, 4, index + 1)
    #    plt.axis('off')
    #    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    #    plt.title('Training: %i' % label)
    classifiers = [
        #("SVM", svm.SVC(gamma=0.001)),  # TODO doesn't finish on the full set; see the subsampled sketch after this function
        ("NN", MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                             solver='sgd', verbose=10, tol=1e-4, random_state=1,
                             learning_rate_init=.1)),
    ]
    for name, classifier in classifiers:
        print(name)
        classifier.fit(X_train, y_train)
        predicted = classifier.predict(X_test)
        print("Classification report for classifier %s:\n%s\n"
              % (classifier, metrics.classification_report(y_test, predicted)))
        print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, predicted))
        #images_and_predictions = list(zip(digits.images[n_samples / 2:], predicted))
        #for index, (image, prediction) in enumerate(images_and_predictions[:4]):
        #    plt.subplot(2, 4, index + 5)
        #    plt.axis('off')
        #    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        #    plt.title('Prediction: %i' % prediction)
    #plt.show()
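

# The commented-out SVC above never finishes on the full 70,000-image set. The sketch
# below is one rough, untested take on the "downsampled version" the TODO hints at:
# fit the SVM on a random subset of the training data. The function name, the
# 10,000-sample subset size, and reusing gamma=0.001 are assumptions, not part of the
# original.
def mnist_svm_subsampled(n_train=10000):
    import numpy as np

    data = datasets.fetch_openml('mnist_784', version=1, as_frame=False)
    X, y = data.data / 255., data.target
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    # Draw a random subset of the training set so the SVM trains in reasonable time.
    rng = np.random.RandomState(1)
    idx = rng.choice(len(X_train), size=n_train, replace=False)

    classifier = svm.SVC(gamma=0.001)
    classifier.fit(X_train[idx], y_train[idx])
    predicted = classifier.predict(X_test)
    print(metrics.classification_report(y_test, predicted))


if __name__ == '__main__':
    mnist()
    # mnist_svm_subsampled()  # optional: the subsampled SVM sketch above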