def test_ransac_max_trials():
    """Check how RANSACRegressor handles ``max_trials`` and ``n_trials_``.

    Two cases are exercised on the module-level ``X`` / ``y`` data:

    * ``max_trials=0`` -- ``fit`` is expected to raise ``ValueError``.
    * ``max_trials=11`` -- ``n_trials_`` is missing (or ``None``) before
      fitting and equals 2 after fitting.
    """
    regressor = LinearRegression()
    # Settings shared by both estimators; only ``max_trials`` differs.
    shared_params = dict(min_samples=2, residual_threshold=5, random_state=0)

    # With no trials allowed, fitting must fail with a ValueError.
    no_trials = RANSACRegressor(regressor, max_trials=0, **shared_params)
    assert_raises(ValueError, no_trials.fit, X, y)

    some_trials = RANSACRegressor(regressor, max_trials=11, **shared_params)
    # Before fitting, the trial counter must be absent or None.
    assert getattr(some_trials, 'n_trials_', None) is None

    some_trials.fit(X, y)
    # The test expects exactly 2 recorded trials although 11 were allowed.
    assert_equal(some_trials.n_trials_, 2)
评论列表
文章目录