test_multiclass.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_ovo_partial_fit_predict():
    X, y = shuffle(iris.data, iris.target)
    ovo1 = OneVsOneClassifier(MultinomialNB())
    ovo1.partial_fit(X[:100], y[:100], np.unique(y))
    ovo1.partial_fit(X[100:], y[100:])
    pred1 = ovo1.predict(X)

    ovo2 = OneVsOneClassifier(MultinomialNB())
    ovo2.fit(X, y)
    pred2 = ovo2.predict(X)
    assert_equal(len(ovo1.estimators_), n_classes * (n_classes - 1) / 2)
    assert_greater(np.mean(y == pred1), 0.65)
    assert_almost_equal(pred1, pred2)

    # Test when mini-batches don't have all target classes
    ovo1 = OneVsOneClassifier(MultinomialNB())
    ovo1.partial_fit(iris.data[:60], iris.target[:60], np.unique(iris.target))
    ovo1.partial_fit(iris.data[60:], iris.target[60:])
    pred1 = ovo1.predict(iris.data)
    ovo2 = OneVsOneClassifier(MultinomialNB())
    pred2 = ovo2.fit(iris.data, iris.target).predict(iris.data)

    assert_almost_equal(pred1, pred2)
    assert_equal(len(ovo1.estimators_), len(np.unique(iris.target)))
    assert_greater(np.mean(iris.target == pred1), 0.65)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号