def test_cross_val_predict():
    """Check that FMClassifier works with ``cross_val_predict``.

    Runs 4-fold shuffled cross-validation on the standardized iris data and
    requires the out-of-fold predictions to reach at least 90% accuracy.
    """
    X, y = load_iris(return_X_y=True)
    # NOTE(review): the scaler is fit on all of X before CV, so the scaling
    # statistics include each fold's held-out rows. Acceptable for a smoke
    # test; consider a pipeline if this threshold is ever tightened.
    X = StandardScaler().fit_transform(X)
    # Deliberately left unfitted: cross_val_predict clones the estimator for
    # every fold, so any fit on the full data here would be discarded work.
    clf = FMClassifier(rank=2, solver='L-BFGS-B', random_state=4567)
    cv = KFold(n_splits=4, random_state=457, shuffle=True)
    # Each sample is predicted by the model trained on the folds excluding it.
    y_oos = cross_val_predict(clf, X, y, cv=cv, method='predict')
    acc = accuracy_score(y, y_oos)
    assert acc >= 0.90, "accuracy is too low for iris in cross_val_predict!"
评论列表
文章目录