test_mlp.py source file

python

Project: Parallel-SGD  Author: angadgill
def test_lbfgs_classification():
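    # Test the L-BFGS solver on classification.
    # The training-set score should exceed 0.95 on both the binary and the
    # multi-class versions of the digits dataset, for every activation type.
    # Depends on names defined elsewhere in test_mlp.py:
    # classification_datasets, ACTIVATION_TYPES, assert_greater, assert_equal.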
    for X, y in classification_datasets:
        X_train = X[:150]
        y_train = y[:150]
        X_test = X[150:]

        expected_shape_dtype = (X_test.shape[0], y_train.dtype.kind)

        for activation in ACTIVATION_TYPES:
            mlp = MLPClassifier(algorithm='l-bfgs', hidden_layer_sizes=50,
                                max_iter=150, shuffle=True, random_state=1,
                                activation=activation)
            mlp.fit(X_train, y_train)
            y_predict = mlp.predict(X_test)
            assert_greater(mlp.score(X_train, y_train), 0.95)
            assert_equal((y_predict.shape[0], y_predict.dtype.kind),
                         expected_shape_dtype)
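
The test above depends on module-level fixtures and on the `algorithm='l-bfgs'` spelling, which appears to come from the pre-release MLP code bundled with this project. As a rough, self-contained sketch of the same check against a released scikit-learn, the version below assumes the parameter is spelled `solver='lbfgs'` and builds its own small digits subset. The helper name `lbfgs_train_score`, the 3-class subset, and the 200-sample cut are illustrative assumptions, not taken from the project.

```python
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MinMaxScaler


def lbfgs_train_score(activation="relu", n_samples=200, n_train=150):
    """Fit an L-BFGS MLP on a small digits subset (hypothetical helper).

    Returns the training score, the test-set predictions, and y_train.
    """
    # Assumption: a 3-class slice of digits stands in for the original fixtures.
    X, y = load_digits(n_class=3, return_X_y=True)
    X = MinMaxScaler().fit_transform(X[:n_samples])
    y = y[:n_samples]

    X_train, y_train = X[:n_train], y[:n_train]
    X_test = X[n_train:]

    # Released scikit-learn spells the option solver="lbfgs".
    mlp = MLPClassifier(solver="lbfgs", hidden_layer_sizes=(50,),
                        max_iter=150, random_state=1,
                        activation=activation)
    mlp.fit(X_train, y_train)
    return mlp.score(X_train, y_train), mlp.predict(X_test), y_train


if __name__ == "__main__":
    score, y_predict, y_train = lbfgs_train_score()
    print(score, y_predict.shape, y_predict.dtype.kind == y_train.dtype.kind)
```

As in the original test, the score is measured on the training split, so it checks that the optimizer can fit the data rather than how well the model generalizes.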