def test_roc_nonrepeating_thresholds():
    """Check that roc_curve does not return spurious repeating thresholds.

    Duplicated thresholds can arise due to machine precision issues: here
    the score is a sum of per-class probabilities, so values that should be
    equal may differ by floating point roundoff. The thresholds are
    therefore compared after rounding to two decimal places.
    """
    dataset = datasets.load_digits()
    X = dataset['data']
    y = dataset['target']

    # This random forest classifier can only return probabilities
    # significant to two decimal places
    clf = ensemble.RandomForestClassifier(n_estimators=100, random_state=0)

    # How well can the classifier predict whether a digit is less than 5?
    # This task contributes floating point roundoff errors to the probabilities
    # Fit on the even-indexed samples and score the odd-indexed ones.
    train, test = slice(None, None, 2), slice(1, None, 2)
    probas_pred = clf.fit(X[train], y[train]).predict_proba(X[test])
    y_score = probas_pred[:, :5].sum(axis=1)  # roundoff errors begin here
    y_true = [yy < 5 for yy in y[test]]

    # Check for repeating values in the thresholds; drop_intermediate=False
    # is passed so that roc_curve is asked not to drop any thresholds.
    _, _, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)
    assert_equal(thresholds.size, np.unique(np.round(thresholds, 2)).size)
# (web page residue) Comment list
# (web page residue) Table of contents