test_cross_validation.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_safe_split_with_precomputed_kernel():
    clf = SVC()
    clfp = SVC(kernel="precomputed")

    iris = load_iris()
    X, y = iris.data, iris.target
    K = np.dot(X, X.T)

    cv = cval.ShuffleSplit(X.shape[0], test_size=0.25, random_state=0)
    tr, te = list(cv)[0]

    X_tr, y_tr = cval._safe_split(clf, X, y, tr)
    K_tr, y_tr2 = cval._safe_split(clfp, K, y, tr)
    assert_array_almost_equal(K_tr, np.dot(X_tr, X_tr.T))

    X_te, y_te = cval._safe_split(clf, X, y, te, tr)
    K_te, y_te2 = cval._safe_split(clfp, K, y, te, tr)
    assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号