test_fm_classifier.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:muffnn 作者: civisanalytics 项目源码 文件源码
def test_cross_val_predict():
    """Make sure it works in cross_val_predict."""

    X, y = load_iris(return_X_y=True)
    X = StandardScaler().fit_transform(X)

    clf = FMClassifier(rank=2, solver='L-BFGS-B', random_state=4567).fit(X, y)

    cv = KFold(n_splits=4, random_state=457, shuffle=True)
    y_oos = cross_val_predict(clf, X, y, cv=cv, method='predict')
    acc = accuracy_score(y, y_oos)

    assert acc >= 0.90, "accuracy is too low for iris in cross_val_predict!"
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号