def test_ransac_multi_dimensional_targets():
    """Check the RANSAC inlier mask when fitting a multi-column target.

    Three copies of ``y`` are stacked column-wise and used as the target of
    a RANSAC-wrapped linear regression. The fitted ``inlier_mask_`` must be
    True everywhere except at the ``outliers`` indices. ``X``, ``y`` and
    ``outliers`` are defined outside this function.
    """
    # Build the multi-column target from three copies of y.
    multi_target = np.column_stack((y, y, y))

    # Fit on the corrupted data with a fixed seed.
    regressor = RANSACRegressor(
        LinearRegression(),
        min_samples=2,
        residual_threshold=5,
        random_state=0,
    )
    regressor.fit(X, multi_target)

    # Reference mask: everything is an inlier except the known outliers.
    found_mask = regressor.inlier_mask_
    expected_mask = np.ones_like(found_mask, dtype=np.bool_)
    expected_mask[outliers] = False

    assert_equal(found_mask, expected_mask)
# XXX: Remove in 0.20
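# NOTE: web-page residue from extraction, not part of the test module:
# "Comment list" / "Article table of contents".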