def generate_prec_recall_points(clf, test_examples, test_labels, pk_file):
    """Compute precision-recall curve points and optionally pickle them.

    Per-class curves are computed one-vs-rest on the binarized labels, plus a
    micro-averaged curve over all (sample, class) pairs.

    Args:
        clf: Classifier wrapper exposing ``model.classes_`` and
            ``predict_raw_prob(examples)``. ``predict_raw_prob`` is assumed to
            return a 2-D score array with one column per entry of
            ``model.classes_``, in that order (TODO confirm against the
            wrapper implementation).
        test_examples: Inputs passed to ``clf.predict_raw_prob``.
        test_labels: True labels for ``test_examples``.
        pk_file: Path of the pickle file to write, or ``None`` to skip writing.

    Returns:
        Tuple ``(precision, recall, average_precision, thresholds)``. Each is a
        dict keyed by binarized label column index (``0 .. k-1``) and by
        ``"micro"``. The same tuple is what gets pickled into ``pk_file``.
    """
    precision = {}
    recall = {}
    average_precision = {}
    thresholds = {}

    # ``classes`` must be passed by keyword in current scikit-learn releases.
    y_test = label_binarize(test_labels, classes=clf.model.classes_)
    y_score = clf.predict_raw_prob(test_examples)
    if y_test.shape[1] == 1:
        # Binary problem: label_binarize yields a single indicator column for
        # the positive class, so keep only the positive-class score column.
        # For multiclass problems every score column is kept so that column i
        # of y_score lines up with column i of y_test.
        y_score = y_score[:, -1:]

    for i in range(y_test.shape[1]):
        precision[i], recall[i], thresholds[i] = precision_recall_curve(
            y_test[:, i], y_score[:, i])
        average_precision[i] = average_precision_score(
            y_test[:, i], y_score[:, i])

    # Micro-average: treat every (sample, class) pair as one binary decision.
    precision["micro"], recall["micro"], thresholds["micro"] = \
        precision_recall_curve(y_test.ravel(), y_score.ravel())
    average_precision["micro"] = average_precision_score(
        y_test, y_score, average="micro")

    result = (precision, recall, average_precision, thresholds)
    if pk_file is not None:
        with open(pk_file, "wb") as f:
            pickle.dump(result, f)
    return result
评论列表
文章目录