test_ransac.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_ransac_fit_sample_weight():
    """Check how RANSACRegressor.fit treats the sample-weight argument.

    Three properties are exercised, in order:

    1. Fitting the shared ``X``/``y`` data with all-ones weights yields an
       inlier mask with one entry per sample, in which only the indices
       listed in ``outliers`` are marked False.
    2. Fitting on data where each row is physically repeated ``n`` times
       gives the same coefficients as fitting each row once with weight
       ``n``.
    3. Passing weights when the base estimator is ``Lasso`` is expected to
       raise ``ValueError``.

    ``X``, ``y`` and ``outliers`` are taken from the enclosing module.
    """
    n_samples = y.shape[0]
    weights = np.ones(n_samples)
    ransac_estimator = RANSACRegressor(random_state=0)
    ransac_estimator.fit(X, y, weights)

    # Sanity check: the mask has one entry per sample.
    fitted_mask = ransac_estimator.inlier_mask_
    assert_equal(fitted_mask.shape[0], n_samples)

    # Build the expected mask: everything is an inlier except `outliers`.
    expected_mask = np.ones_like(fitted_mask).astype(np.bool_)
    expected_mask[outliers] = False
    assert_array_equal(fitted_mask, expected_mask)

    # fit(X) must equal fit([X1, X2, X3], sample_weight=[n1, n2, n3]) where
    # X is X1 repeated n1 times, X2 repeated n2 times, and so forth.
    # The draws below must stay in this order so the generated data is
    # reproducible from the seeded random state.
    rng = check_random_state(0)
    X_ = rng.randint(0, 200, [10, 1])
    y_ = (0.2 * X_ + 2).flatten()
    sample_weight = rng.randint(0, 10, 10)
    outlier_X = rng.randint(0, 1000, [1, 1])
    outlier_weight = rng.randint(0, 10, 1)
    outlier_y = rng.randint(-1000, 0, 1)

    # Reference fit: every row is duplicated according to its weight.
    repeated_X = np.repeat(X_, sample_weight, axis=0)
    repeated_outlier_X = np.repeat(outlier_X, outlier_weight, axis=0)
    X_flat = np.append(repeated_X, repeated_outlier_X, axis=0)

    repeated_y = np.repeat(y_, sample_weight, axis=0)
    repeated_outlier_y = np.repeat(outlier_y, outlier_weight, axis=0)
    y_flat = np.append(repeated_y, repeated_outlier_y, axis=0).flatten()

    ransac_estimator.fit(X_flat, y_flat)
    ref_coef_ = ransac_estimator.estimator_.coef_

    # Weighted fit: each row appears once and carries its repeat count.
    weighted_X = np.append(X_, outlier_X, axis=0)
    weighted_y = np.append(y_, outlier_y)
    all_weights = np.append(sample_weight, outlier_weight)
    ransac_estimator.fit(weighted_X, weighted_y, all_weights)

    assert_almost_equal(ransac_estimator.estimator_.coef_, ref_coef_)

    # If the base estimator's fit does not support sample_weight, fitting
    # with weights must raise an error.
    unweighted_ransac = RANSACRegressor(Lasso())
    assert_raises(ValueError, unweighted_ransac.fit, X, y, weights)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号