test_cross_validation.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_cross_val_predict():
    """Check ``cval.cross_val_predict`` against a hand-written CV loop.

    Scenarios exercised, in order:

    * an explicit ``KFold`` iterator, whose predictions must match a naive
      fit/predict loop over the same folds;
    * the default ``cv`` argument;
    * a ``LeaveOneOut`` iterator;
    * a sparse (COO) feature matrix;
    * an estimator called without ``y`` (``KMeans``);
    * a CV generator whose folds must be rejected with ``ValueError``.
    """
    boston = load_boston()
    X, y = boston.data, boston.target
    n_samples = len(y)
    cv = cval.KFold(n_samples)

    est = Ridge()

    # Naive loop (should be same as cross_val_predict):
    preds2 = np.zeros_like(y)
    for train, test in cv:
        est.fit(X[train], y[train])
        preds2[test] = est.predict(X[test])

    preds = cval.cross_val_predict(est, X, y, cv=cv)
    assert_array_almost_equal(preds, preds2)

    # Default cv: only the number of returned predictions is checked.
    preds = cval.cross_val_predict(est, X, y)
    assert_equal(len(preds), n_samples)

    cv = cval.LeaveOneOut(n_samples)
    preds = cval.cross_val_predict(est, X, y, cv=cv)
    assert_equal(len(preds), n_samples)

    # Zero out every entry that is not above the global median, then
    # convert to COO so the sparse-input path is exercised.
    Xsp = X.copy()
    Xsp *= (Xsp > np.median(Xsp))
    Xsp = coo_matrix(Xsp)
    preds = cval.cross_val_predict(est, Xsp, y)
    # Lengths are plain integers: use the exact-equality assertion, as the
    # other length checks in this test do.
    assert_equal(len(preds), n_samples)

    # Estimator fitted without a target vector.
    preds = cval.cross_val_predict(KMeans(), X)
    assert_equal(len(preds), n_samples)

    def bad_cv():
        # Yields the identical train/test split four times, so the test
        # indices repeat and cannot form a partition of the samples.
        for _ in range(4):
            yield np.array([0, 1, 2, 3]), np.array([4, 5, 6, 7, 8])

    assert_raises(ValueError, cval.cross_val_predict, est, X, y, cv=bad_cv())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号