def test_safe_split_with_precomputed_kernel():
    """Check ``_safe_split`` on a precomputed kernel against raw features.

    The same train/test indices are applied twice: once to the feature
    matrix ``X`` with a default ``SVC``, and once to the Gram matrix
    ``K = X . X^T`` with ``SVC(kernel="precomputed")``.  The kernel subsets
    returned for the precomputed estimator must match the Gram matrices
    rebuilt from the corresponding feature subsets, and the label subsets
    must be the same for both estimators.
    """
    clf = SVC()
    clfp = SVC(kernel="precomputed")

    iris = load_iris()
    X, y = iris.data, iris.target
    # Gram matrix of the full data set (plain dot products between samples).
    K = np.dot(X, X.T)

    cv = cval.ShuffleSplit(X.shape[0], test_size=0.25, random_state=0)
    # Only the first split is used, so avoid materialising all of them.
    tr, te = next(iter(cv))

    # Training subset: the kernel must be restricted to train x train.
    X_tr, y_tr = cval._safe_split(clf, X, y, tr)
    K_tr, y_tr2 = cval._safe_split(clfp, K, y, tr)
    assert_array_almost_equal(K_tr, np.dot(X_tr, X_tr.T))
    # The labels do not depend on the kernel type.
    np.testing.assert_array_equal(y_tr2, y_tr)

    # Test subset: the kernel must be restricted to test x train, which is
    # why the training indices are passed as the extra argument.
    X_te, y_te = cval._safe_split(clf, X, y, te, tr)
    K_te, y_te2 = cval._safe_split(clfp, K, y, te, tr)
    assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T))
    np.testing.assert_array_equal(y_te2, y_te)
# Comment list
# Table of contents