test_cross_validation.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_cross_val_score_mask():
    """Check that ``cross_val_score`` works with boolean-mask CV splits.

    The scores obtained from index-based K-fold splits must equal the
    scores obtained when the very same splits are supplied as pairs of
    boolean masks over the samples.
    """
    svm = SVC(kernel="linear")
    iris = load_iris()
    X, y = iris.data, iris.target
    n_samples = len(y)

    # Reference scores: folds expressed as the splitter yields them.
    cv_indices = cval.KFold(n_samples, 5)
    scores_indices = cval.cross_val_score(svm, X, y, cv=cv_indices)

    # Same folds, re-expressed as boolean masks of length n_samples.
    cv_masks = []
    for train, test in cval.KFold(n_samples, 5):
        # Builtin ``bool`` is used as dtype: the ``np.bool`` alias is
        # deprecated and no longer exists in recent NumPy releases.
        mask_train = np.zeros(n_samples, dtype=bool)
        mask_test = np.zeros(n_samples, dtype=bool)
        mask_train[train] = True
        mask_test[test] = True
        # Append the masks (not the raw splits) so that the mask code
        # path is what actually gets exercised below.
        cv_masks.append((mask_train, mask_test))
    scores_masks = cval.cross_val_score(svm, X, y, cv=cv_masks)

    assert_array_equal(scores_indices, scores_masks)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号