def test_probability():
    """Check predict_proba / predict_log_proba for SVC and NuSVC on iris.

    Each probabilistic classifier is fitted and evaluated on the full iris
    data (no held-out split). The test then verifies that:

    * every row of ``predict_proba`` sums to 1,
    * the most probable class agrees with ``predict`` on more than 90% of
      the samples, and
    * ``predict_proba`` equals ``exp(predict_log_proba)`` to 8 decimals.
    """
    # probability=True makes the estimator calibrate its probabilities with
    # an internal cross-validation, so the most probable class is not
    # guaranteed to match predict() on every sample -- hence the > 0.9
    # agreement threshold rather than an exact comparison.
    for clf in (svm.SVC(probability=True, random_state=0, C=1.0),
                svm.NuSVC(probability=True, random_state=0)):
        clf.fit(iris.data, iris.target)
        prob_predict = clf.predict_proba(iris.data)

        # Each row must be a probability distribution over the classes.
        assert_array_almost_equal(
            np.sum(prob_predict, 1), np.ones(iris.data.shape[0]))

        # argmax yields column indices; map them through classes_ so the
        # comparison with predict() is label-based, not index-based.
        most_probable = clf.classes_[np.argmax(prob_predict, 1)]
        agreement = np.mean(most_probable == clf.predict(iris.data))
        # Plain assert (instead of nose-style assert_true) lets pytest show
        # the actual agreement value on failure.
        assert agreement > 0.9

        # Reuse prob_predict rather than calling predict_proba a second time.
        assert_almost_equal(prob_predict,
                            np.exp(clf.predict_log_proba(iris.data)), 8)
评论列表
文章目录