test_mlp_classifier.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:muffnn 作者: civisanalytics 项目源码 文件源码
def test_cross_val_predict():
    # Make sure it works in cross_val_predict for multiclass.

    X, y = load_iris(return_X_y=True)
    y = LabelBinarizer().fit_transform(y)
    X = StandardScaler().fit_transform(X)

    mlp = MLPClassifier(n_epochs=10,
                        solver_kwargs={'learning_rate': 0.05},
                        random_state=4567).fit(X, y)

    cv = KFold(n_splits=4, random_state=457, shuffle=True)
    y_oos = cross_val_predict(mlp, X, y, cv=cv, method='predict_proba')
    auc = roc_auc_score(y, y_oos, average=None)

    assert np.all(auc >= 0.96)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号