main.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:deeplearning 作者: turiphro 项目源码 文件源码
def mnist():
    """Train and evaluate classifiers on the MNIST digits dataset.

    Downloads MNIST from OpenML, scales the pixel values by 1/255, splits
    the samples into the first 60000 for training and the remainder for
    testing, then fits every classifier in the local ``classifiers`` list
    and prints its classification report and confusion matrix.

    Returns:
        None. All results are written to stdout.
    """
    # ``fetch_mldata`` was removed from scikit-learn (and mldata.org is
    # offline); ``fetch_openml`` serves the same dataset in the same sample
    # order. ``as_frame=False`` keeps the data as NumPy arrays so the
    # positional slicing below works.
    mnist = datasets.fetch_openml("mnist_784", version=1, as_frame=False)
    print("Got the data.")

    # OpenML delivers the labels as strings; cast them to integers so the
    # targets stay numeric like they were with the old loader.
    X, y = mnist.data / 255., mnist.target.astype(int)

    # The split below relies on the dataset order: training samples first,
    # test samples after.
    n_train = 60000
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]

    classifiers = [
        #("SVM", svm.SVC(gamma=0.001)), # TODO doesn't finish; needs downsampled version?
        ("NN", MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                             solver='sgd', verbose=10, tol=1e-4, random_state=1,
                             learning_rate_init=.1)),
    ]

    for name, classifier in classifiers:
        print(name)
        classifier.fit(X_train, y_train)
        predicted = classifier.predict(X_test)

        print("Classification report for classifier %s:\n%s\n"
              % (classifier, metrics.classification_report(y_test, predicted)))
        print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, predicted))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号