def test_roc_curve_one_label():
    """Check roc_curve when y_true contains only a single class.

    The same alternating predictions are scored against two degenerate
    label vectors:

    * every label is 1 (no negatives): every FPR entry should be NaN;
    * every label is 0 (no positives): every TPR entry should be NaN.

    Each call is expected to emit an ``UndefinedMetricWarning``, and the
    returned ``fpr``, ``tpr`` and ``thresholds`` must share one shape.
    """
    all_positive = [1] * 10
    scores = [0, 1] * 5
    expected_warning = UndefinedMetricWarning

    # Only positive samples: with no negatives the false-positive rate is
    # undefined, so it should come back as NaN at every threshold.
    fpr, tpr, thresholds = assert_warns(expected_warning, roc_curve,
                                        all_positive, scores)
    nan_row = np.full(len(thresholds), np.nan)
    assert_array_equal(fpr, nan_row)
    assert_equal(fpr.shape, tpr.shape)
    assert_equal(fpr.shape, thresholds.shape)

    # Flip every label: only negative samples remain, so now the
    # true-positive rate is the one that should be NaN everywhere.
    all_negative = [1 - label for label in all_positive]
    fpr, tpr, thresholds = assert_warns(expected_warning, roc_curve,
                                        all_negative, scores)
    nan_row = np.full(len(thresholds), np.nan)
    assert_array_equal(tpr, nan_row)
    assert_equal(fpr.shape, tpr.shape)
    assert_equal(fpr.shape, thresholds.shape)
评论列表
文章目录