def test_ovr_multilabel_dataset():
    """Check OneVsRestClassifier on a synthetic multilabel dataset.

    A one-vs-rest wrapper around ``MultinomialNB(alpha=1)`` is fit on the
    first 80 samples of a ``make_multilabel_classification`` dataset
    (100 samples, 20 features, 5 classes, ``random_state=0``). The test then
    checks that the fitted classifier's ``multilabel_`` flag is set. It also
    compares the micro-averaged precision and recall on the remaining 20
    samples against reference values, to 2 decimal places. This runs once
    with ``allow_unlabeled=True`` and once with ``allow_unlabeled=False``.
    """
    base_clf = MultinomialNB(alpha=1)
    # Each case: (allow_unlabeled, expected micro precision,
    #             expected micro recall).
    cases = (
        (True, 0.51, 0.51),
        (False, 0.66, 0.80),
    )
    for allow_unlabeled, prec, recall in cases:
        X, Y = datasets.make_multilabel_classification(
            n_samples=100,
            n_features=20,
            n_classes=5,
            n_labels=2,
            length=50,
            allow_unlabeled=allow_unlabeled,
            random_state=0,
        )
        # Fixed 80/20 train/test split; the dataset is seeded, so the
        # reference scores below are reproducible.
        X_train, Y_train = X[:80], Y[:80]
        X_test, Y_test = X[80:], Y[80:]
        clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)
        Y_pred = clf.predict(X_test)

        assert clf.multilabel_
        assert_almost_equal(precision_score(Y_test, Y_pred, average="micro"),
                            prec,
                            decimal=2)
        assert_almost_equal(recall_score(Y_test, Y_pred, average="micro"),
                            recall,
                            decimal=2)