def test_ransac_is_data_valid():
    """Check that RANSAC raises ValueError when ``is_data_valid`` rejects all data.

    The ``is_data_valid`` callback always returns False, so ``fit`` is
    expected to find no acceptable subset and raise ``ValueError``. The
    callback also asserts that each subset it receives has exactly
    ``min_samples`` (2) rows.
    """
    # Mutable list instead of ``nonlocal`` so the closure can record that it
    # was invoked (keeps Python 2 compatibility).
    calls = []

    def is_data_valid(X, y):
        calls.append(1)
        # Each candidate subset should contain exactly ``min_samples`` rows.
        assert_equal(X.shape[0], 2)
        assert_equal(y.shape[0], 2)
        # Reject every subset so no valid consensus set can be formed.
        return False

    # Seeded generator so the test data are reproducible across runs.
    rng = np.random.RandomState(0)
    X = rng.rand(10, 2)
    y = rng.rand(10, 1)

    base_estimator = LinearRegression()
    ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
                                       residual_threshold=5,
                                       is_data_valid=is_data_valid,
                                       random_state=0)

    assert_raises(ValueError, ransac_estimator.fit, X, y)
    # Guard against a vacuous pass: make sure the callback (and therefore the
    # shape assertions inside it) was actually exercised during ``fit``.
    assert calls, "is_data_valid was never called"
# Scraped web-page residue (not code), translated:
#   "评论列表" = "Comment list"
#   "文章目录" = "Table of contents"