def test_neighbors_digits():
    """Sanity check of brute-force 1-NN on the digits dataset.

    The 'brute' algorithm has been observed to fail if the input dtype is
    uint8, due to overflow in distance calculations. This test scores the
    same classifier on uint8 data and on the same data cast to float, and
    requires both scores to be equal.
    """
    features = digits.data.astype('uint8')
    labels = digits.target

    # Unpacking two values keeps the implicit requirement that the data
    # is two-dimensional; only the row count is needed afterwards.
    n_rows, _ = features.shape

    # First 80% of the rows are used for fitting, the remainder for scoring.
    split = int(n_rows * 0.8)
    fit_idx = np.arange(0, split)
    eval_idx = np.arange(split, n_rows)

    fit_X = features[fit_idx]
    fit_y = labels[fit_idx]
    eval_X = features[eval_idx]
    eval_y = labels[eval_idx]

    knn = neighbors.KNeighborsClassifier(n_neighbors=1, algorithm='brute')

    # Score on the raw uint8 data first ...
    fitted_uint8 = knn.fit(fit_X, fit_y)
    uint8_accuracy = fitted_uint8.score(eval_X, eval_y)

    # ... then refit the same estimator on the float cast of the same data.
    fitted_float = knn.fit(fit_X.astype(float), fit_y)
    float_accuracy = fitted_float.score(eval_X.astype(float), eval_y)

    assert_equal(uint8_accuracy, float_accuracy)
评论列表
文章目录