def test_ovr_multilabel():
    """Check OneVsRestClassifier on a multilabel indicator target.

    For each base estimator in the loop, the fitted one-vs-rest wrapper must
    predict the label row ``[0, 1, 1]`` for the query point ``[0, 4, 4]`` and
    expose a truthy ``multilabel_`` attribute.
    """
    # Toy dataset where features correspond directly to labels.
    X = np.array([[0, 4, 5], [0, 5, 0], [3, 3, 3], [4, 0, 6], [6, 0, 0]])
    # Label-indicator matrix: one row per sample, one 0/1 column per label.
    y = np.array([[0, 1, 1],
                  [0, 1, 0],
                  [1, 1, 1],
                  [1, 0, 1],
                  [1, 0, 0]])

    for base_clf in (MultinomialNB(), LinearSVC(random_state=0),
                     LinearRegression(), Ridge(),
                     ElasticNet(), Lasso(alpha=0.5)):
        # Name the estimator in failure messages so that a failing loop
        # iteration can be identified from the test report alone.
        name = type(base_clf).__name__
        clf = OneVsRestClassifier(base_clf).fit(X, y)
        y_pred = clf.predict([[0, 4, 4]])[0]
        assert_array_equal(y_pred, [0, 1, 1],
                           err_msg="unexpected prediction for %s" % name)
        # Plain assert rather than a nose-style helper: pytest reports the
        # failing expression and message directly.
        assert clf.multilabel_, "multilabel_ should be true for %s" % name
评论列表
文章目录