def test_ransac_residual_metric():
    """Check RANSAC fits with a custom ``residual_metric`` against the default.

    Three ``RANSACRegressor`` instances share the same base estimator,
    ``min_samples``, ``residual_threshold`` and ``random_state``; two of them
    additionally receive a ``residual_metric`` callable. Fitting those two is
    expected to raise a ``DeprecationWarning``, and their predictions must
    match the estimator built without ``residual_metric``.

    ``X`` and ``y`` are taken from the enclosing scope; they are not defined
    in this block.
    """

    def abs_sum_metric(dy):
        # Sum of absolute residuals across axis 1.
        return np.sum(np.abs(dy), axis=1)

    def squared_sum_metric(dy):
        # Sum of squared residuals across axis 1.
        return np.sum(dy ** 2, axis=1)

    # Target with three identical columns, used for the multi-output case.
    stacked_y = np.column_stack([y, y, y])

    # Settings common to all three estimators, so that only the residual
    # metric differs between them.
    shared_kwargs = dict(min_samples=2, residual_threshold=5, random_state=0)

    base_estimator = LinearRegression()
    ransac_estimator0 = RANSACRegressor(base_estimator, **shared_kwargs)
    ransac_estimator1 = RANSACRegressor(base_estimator,
                                        residual_metric=abs_sum_metric,
                                        **shared_kwargs)
    ransac_estimator2 = RANSACRegressor(base_estimator,
                                        residual_metric=squared_sum_metric,
                                        **shared_kwargs)
    custom_metric_estimators = (ransac_estimator1, ransac_estimator2)

    # multi-dimensional
    ransac_estimator0.fit(X, stacked_y)
    for estimator in custom_metric_estimators:
        assert_warns(DeprecationWarning, estimator.fit, X, stacked_y)
    for estimator in custom_metric_estimators:
        assert_array_almost_equal(ransac_estimator0.predict(X),
                                  estimator.predict(X))

    # one-dimensional (only the squared-sum metric is exercised here)
    ransac_estimator0.fit(X, y)
    assert_warns(DeprecationWarning, ransac_estimator2.fit, X, y)
    assert_array_almost_equal(ransac_estimator0.predict(X),
                              ransac_estimator2.predict(X))
评论列表
文章目录