test_split.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_cross_validator_with_default_params():
    """Check the basic contract of each cross-validator built with defaults.

    For every splitter constructed below, the test verifies that:

    * ``get_n_splits`` reports the expected number of splits,
    * ``split`` yields the same indices for 2-d and 1-d input data,
    * both the train and the test indices are integer arrays,
    * ``repr`` matches the expected string.
    """
    n_samples = 4
    n_unique_labels = 4
    n_folds = 2
    p = 2
    n_iter = 10  # (the default value)

    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    X_1d = np.array([1, 2, 3, 4])
    y = np.array([1, 1, 2, 2])
    labels = np.array([1, 2, 3, 4])

    # Each entry is (cross-validator, expected repr, expected n_splits).
    cases = [
        (LeaveOneOut(), "LeaveOneOut()", n_samples),
        (LeavePOut(p), "LeavePOut(p=2)", comb(n_samples, p)),
        (KFold(n_folds),
         "KFold(n_folds=2, random_state=None, shuffle=False)",
         n_folds),
        (StratifiedKFold(n_folds),
         "StratifiedKFold(n_folds=2, random_state=None, shuffle=False)",
         n_folds),
        (LeaveOneLabelOut(), "LeaveOneLabelOut()", n_unique_labels),
        (LeavePLabelOut(p), "LeavePLabelOut(n_labels=2)",
         comb(n_unique_labels, p)),
        (ShuffleSplit(random_state=0),
         ("ShuffleSplit(n_iter=10, random_state=0, test_size=0.1, "
          "train_size=None)"),
         n_iter),
        # n_splits = number of unique folds = 2
        (PredefinedSplit([1, 1, 2, 2]),
         "PredefinedSplit(test_fold=array([1, 1, 2, 2]))",
         2),
    ]

    for cv, cv_repr, expected_n_splits in cases:
        # Test if get_n_splits works correctly
        assert_equal(expected_n_splits, cv.get_n_splits(X, y, labels))

        # Test if the cross-validator works as expected even if
        # the data is 1d
        np.testing.assert_equal(list(cv.split(X, y, labels)),
                                list(cv.split(X_1d, y, labels)))

        # Test that train, test indices returned are integers.
        # The second assertion must inspect ``test``; checking ``train``
        # twice would leave the test indices unverified.
        for train, test in cv.split(X, y, labels):
            assert_equal(np.asarray(train).dtype.kind, 'i')
            assert_equal(np.asarray(test).dtype.kind, 'i')

        # Test if the repr works without any errors
        assert_equal(cv_repr, repr(cv))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号