def test_cart_d1_agrees_with_scikit():
    """Check that a depth-1 GaussCART fits as well as scikit-learn's depth-1 tree.

    Both models are trained on the module-level ``X`` / ``y`` and scored on
    that same data. Their in-sample ``sose`` errors must agree to within a
    small floating-point tolerance.
    """
    # NOTE(review): the third GaussCART argument is assumed to be the tree
    # depth, mirroring ``max_depth=1`` below — confirm against GaussCART.
    d_cart = GaussCART(X, y, 1)
    d_pred = d_cart.predict(X)

    sk_cart = tree.DecisionTreeRegressor(max_depth=1)
    sk_cart = sk_cart.fit(X, y)
    sk_pred = sk_cart.predict(X)

    # presumably ``sose`` is a sum-of-squared-errors helper — verify.
    d_error = sose(y, d_pred)
    sk_error = sose(y, sk_pred)

    # Compare with a tolerance instead of rounding both values and using
    # ``==``. Two errors that differ only by summation-order noise can still
    # round to different 6th decimals, which made the old check flaky.
    assert np.isclose(d_error, sk_error, rtol=1e-6, atol=1e-6), (
        f"GaussCART error {d_error!r} != scikit-learn error {sk_error!r}"
    )
评论列表
文章目录