def test_cross_val_predict():
    """Make sure the classifier works in cross_val_predict for multiclass.

    Loads the iris data, binarizes the labels and standardizes the
    features, then collects out-of-sample ``predict_proba`` output via a
    shuffled 4-fold split and asserts every per-class ROC AUC is at
    least 0.96.
    """
    features, labels = load_iris(return_X_y=True)
    targets = LabelBinarizer().fit_transform(labels)
    scaled = StandardScaler().fit_transform(features)

    # Fixed seeds on both the model and the splitter keep the run repeatable.
    model_params = dict(
        n_epochs=10,
        solver_kwargs={'learning_rate': 0.05},
        random_state=4567,
    )
    classifier = MLPClassifier(**model_params).fit(scaled, targets)
    folds = KFold(n_splits=4, random_state=457, shuffle=True)

    oos_proba = cross_val_predict(
        classifier, scaled, targets, cv=folds, method='predict_proba'
    )

    # average=None keeps one score per class rather than a single average.
    per_class_auc = roc_auc_score(targets, oos_proba, average=None)
    assert np.all(per_class_auc >= 0.96)
# Comment list
# Table of contents