def test_ransac_is_model_valid():
    """Check that fit raises ValueError when every candidate model is rejected.

    A custom ``is_model_valid`` callback is passed to ``RANSACRegressor``.
    The callback verifies the size of the subset it receives and then
    rejects the model unconditionally, so fitting is expected to fail with
    a ``ValueError``.
    """

    def is_model_valid(estimator, X, y):
        # The callback should be handed exactly two samples; this matches
        # the ``min_samples=2`` passed to the RANSAC estimator below.
        assert_equal(X.shape[0], 2)
        assert_equal(y.shape[0], 2)
        # Reject every candidate model.
        return False

    base_estimator = LinearRegression()
    ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
                                       residual_threshold=5,
                                       is_model_valid=is_model_valid,
                                       random_state=0)

    # X and y are the module-level fixtures defined elsewhere in this file.
    assert_raises(ValueError, ransac_estimator.fit, X, y)
评论列表
文章目录