def test_classifier(self):
    """Check a JoblibedClassifier-wrapped random forest on the iris data.

    First, a default ``RandomForestClassifier`` wrapped under the name
    ``"rf"`` is fitted with an explicit sample index and must reach an
    accuracy above 0.9 on the same data.  Second, a differently
    configured forest (``n_estimators=20``) wrapped under the same name
    is fitted without an index and must produce exactly the same
    predictions as the first wrapper.
    """
    features = self.iris.data
    labels = self.iris.target

    # One integer id per sample, passed alongside the data.
    sample_ids = list(range(len(features)))

    base_forest = RandomForestClassifier()
    wrapped = JoblibedClassifier(base_forest, "rf", cache_dir='')
    wrapped.fit(features, labels, sample_ids)
    first_pred = wrapped.predict(features, sample_ids)

    accuracy = accuracy_score(labels, first_pred)
    self.assertGreater(
        accuracy, 0.9, "Failed with score = {0}".format(accuracy)
    )

    # Same wrapper name, different underlying estimator settings.
    # NOTE(review): presumably the wrapper reuses the result stored under
    # "rf", which is why identical predictions are expected -- confirm
    # against JoblibedClassifier.
    other_forest = RandomForestClassifier(n_estimators=20)
    rewrapped = JoblibedClassifier(other_forest, "rf", cache_dir='')
    rewrapped.fit(features, labels)

    # Build a fresh id list for the second prediction call.
    sample_ids = list(range(len(features)))
    second_pred = rewrapped.predict(features, sample_ids)

    self.assertTrue((first_pred == second_pred).all())
评论列表
文章目录