test_common.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_sample_order_invariance_multilabel_and_multioutput():
    """Check that multilabel / multioutput metrics ignore sample order.

    Random indicator targets, predictions and continuous scores are drawn
    with a fixed seed. They are then shuffled jointly (the same row
    permutation for all three arrays). Every metric in each family must
    give (almost) the same value on the original and the shuffled inputs.
    """
    rng = check_random_state(0)

    # Draw order matters for reproducibility: targets, then predictions,
    # then scores, all sharing the same (n_samples, n_outputs) shape.
    n_samples, n_outputs = 20, 25
    y_true = rng.randint(0, 2, size=(n_samples, n_outputs))
    y_pred = rng.randint(0, 2, size=(n_samples, n_outputs))
    y_score = rng.normal(size=y_true.shape)

    # Shuffle all three arrays with one permutation so rows stay aligned.
    (y_true_shuffle,
     y_pred_shuffle,
     y_score_shuffle) = shuffle(y_true, y_pred, y_score, random_state=0)

    # Each entry pairs the (y_true, second_arg) inputs in original order
    # with the same inputs in shuffled order.
    pred_inputs = ((y_true, y_pred), (y_true_shuffle, y_pred_shuffle))
    score_inputs = ((y_true, y_score), (y_true_shuffle, y_score_shuffle))

    # Metric family -> input pairs to compare, in the order they are checked.
    # Multioutput metrics are checked on the scores first, then on the
    # predictions.
    families = (
        (MULTILABELS_METRICS, (pred_inputs,)),
        (THRESHOLDED_MULTILABEL_METRICS, (score_inputs,)),
        (MULTIOUTPUT_METRICS, (score_inputs, pred_inputs)),
    )

    for metric_names, comparisons in families:
        for metric_name in metric_names:
            scorer = ALL_METRICS[metric_name]
            for ordered_args, shuffled_args in comparisons:
                assert_almost_equal(
                    scorer(*ordered_args),
                    scorer(*shuffled_args),
                    err_msg="%s is not sample order invariant" % metric_name,
                )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号