def test_sample_order_invariance_multilabel_and_multioutput():
    """Check that metric values do not change when samples are reordered.

    Random data of shape (20, 25) is generated, then the targets,
    predictions and scores are shuffled together. Every metric is evaluated
    on the original and on the shuffled arrays and the two results must be
    almost equal:

    * metrics named in ``MULTILABELS_METRICS`` are fed (targets, predictions);
    * metrics named in ``THRESHOLDED_MULTILABEL_METRICS`` are fed
      (targets, scores);
    * metrics named in ``MULTIOUTPUT_METRICS`` are fed both pairs.
    """
    rng = check_random_state(0)

    # Integer arrays drawn from {0, 1}, plus normally distributed scores
    # with the same shape as the targets.
    targets = rng.randint(0, 2, size=(20, 25))
    predictions = rng.randint(0, 2, size=(20, 25))
    scores = rng.normal(size=targets.shape)

    # A single call shuffles the three arrays together.
    # NOTE(review): presumably the same permutation is applied to each array,
    # keeping rows aligned -- confirm against the ``shuffle`` helper in use.
    targets_perm, predictions_perm, scores_perm = shuffle(
        targets, predictions, scores, random_state=0
    )

    # (original, shuffled) argument pairs for the two kinds of second input.
    with_predictions = ((targets, predictions),
                        (targets_perm, predictions_perm))
    with_scores = ((targets, scores), (targets_perm, scores_perm))

    def assert_same_after_shuffle(name, metric, original, shuffled):
        # Evaluate on the original order first, then on the shuffled order.
        assert_almost_equal(
            metric(*original),
            metric(*shuffled),
            err_msg="%s is not sample order invariant" % name,
        )

    for name in MULTILABELS_METRICS:
        assert_same_after_shuffle(name, ALL_METRICS[name], *with_predictions)

    for name in THRESHOLDED_MULTILABEL_METRICS:
        assert_same_after_shuffle(name, ALL_METRICS[name], *with_scores)

    for name in MULTIOUTPUT_METRICS:
        metric = ALL_METRICS[name]
        # These metrics are checked with real-valued scores and with the
        # integer predictions as the second argument.
        assert_same_after_shuffle(name, metric, *with_scores)
        assert_same_after_shuffle(name, metric, *with_predictions)
评论列表
文章目录