def test_predict():
    """Fit a default Id3Estimator on the breast-cancer data and check predictions.

    Two expectations are asserted:

    * predicting on the training features reproduces the training targets;
    * a single hand-entered sample of 30 feature values is assigned class 0.
    """
    clf = Id3Estimator()
    dataset = load_breast_cancer()
    features, labels = dataset.data, dataset.target
    clf.fit(features, labels)

    # Written as a nested list so the array is already 2-D:
    # one sample (row) by 30 feature values (columns).
    single_row = np.array([[
        20.57, 17.77, 132.9, 1326, 0.08474, 0.07864, 0.0869,
        0.07017, 0.1812, 0.05667, 0.5435, 0.7339, 3.398, 74.08,
        0.005225, 0.01308, 0.0186, 0.0134, 0.01389, 0.003532,
        24.99, 23.41, 158.8, 1956, 0.1238, 0.1866, 0.2416,
        0.186, 0.275, 0.08902,
    ]])

    training_predictions = clf.predict(features)
    assert_almost_equal(training_predictions, labels)

    sample_prediction = clf.predict(single_row)
    assert_almost_equal(sample_prediction, 0)
评论列表
文章目录