def test_ovr_partial_fit():
    """Check that incremental ``partial_fit`` training matches a single ``fit``.

    Two scenarios are run on the iris data with
    ``OneVsRestClassifier(MultinomialNB())``:

    1. shuffled samples, fed as the first 100 rows and then the remainder;
    2. samples in their stored order, fed as the first 60 rows and then the
       remainder, to cover mini-batches that do not contain every class.

    In both scenarios the predictions must agree with those of a classifier
    trained by one ``fit`` call on the full data, there must be one fitted
    estimator per class, and training accuracy must exceed 0.65.
    """

    def _compare_incremental_to_batch(samples, labels, first_batch):
        # Train one model in two partial_fit steps and one with a single
        # fit, then verify both behave the same on the training data.
        classes = np.unique(labels)
        head = slice(None, first_batch)
        tail = slice(first_batch, None)

        incremental = OneVsRestClassifier(MultinomialNB())
        # The full class list is supplied on the first call only.
        incremental.partial_fit(samples[head], labels[head], classes)
        incremental.partial_fit(samples[tail], labels[tail])
        incremental_pred = incremental.predict(samples)

        batch = OneVsRestClassifier(MultinomialNB())
        batch_pred = batch.fit(samples, labels).predict(samples)

        assert_almost_equal(incremental_pred, batch_pred)
        assert_equal(len(incremental.estimators_), len(classes))
        assert_greater(np.mean(labels == incremental_pred), 0.65)

    # Scenario 1: partial_fit works as intended on shuffled data.
    shuffled_X, shuffled_y = shuffle(iris.data, iris.target, random_state=0)
    _compare_incremental_to_batch(shuffled_X, shuffled_y, 100)

    # Scenario 2: mini-batches that do not contain all classes.
    _compare_incremental_to_batch(iris.data, iris.target, 60)
评论列表
文章目录