def get_best_threshold(y_ref, y_pred_score, plot=False):
    """Find the score threshold that maximizes the class-balanced F1 score.

    Parameters
    ----------
    y_ref : np.ndarray
        Reference labels (binary, values in {0, 1}).
    y_pred_score : np.ndarray
        Predicted scores; higher means more likely positive.
    plot : bool
        If True, show the precision-recall curve and the F1 score as a
        function of the probability threshold.

    Returns
    -------
    best_threshold : float
        Threshold on score that maximizes the F1 score.
    max_fscore : float
        F1 score achieved at best_threshold.
    plot_data : tuple
        (recall, precision, thresholds, f1_scores) arrays for external
        plotting; f1_scores is aligned element-wise with thresholds.

    Raises
    ------
    ValueError
        If y_ref is empty.
    """
    y_ref = np.asarray(y_ref)
    if y_ref.size == 0:
        raise ValueError("y_ref must not be empty")

    # Balance the classes: each class is weighted by the other class's
    # prevalence, so both classes contribute equally to the PR curve.
    pos_weight = 1.0 - np.mean(y_ref == 1)
    neg_weight = 1.0 - np.mean(y_ref == 0)
    sample_weight = np.zeros(y_ref.shape)
    sample_weight[y_ref == 1] = pos_weight
    sample_weight[y_ref == 0] = neg_weight

    print("max prediction value = %s" % np.max(y_pred_score))
    print("min prediction value = %s" % np.min(y_pred_score))

    precision, recall, thresholds = \
        metrics.precision_recall_curve(y_ref, y_pred_score, pos_label=1,
                                       sample_weight=sample_weight)

    # F-beta with beta=1, i.e. plain F1. Points with precision == recall == 0
    # produce NaN (0/0); np.nanargmax below skips them, so silence the
    # expected invalid-value warnings.
    beta = 1.0
    btasq = beta**2.0
    with np.errstate(divide='ignore', invalid='ignore'):
        fbeta_scores = ((1.0 + btasq) * (precision * recall)
                        / ((btasq * precision) + recall))

    # precision/recall have one more entry than thresholds (the final
    # recall=0, precision=1 point has no threshold), so search only the
    # first len(thresholds) scores to avoid an out-of-bounds index.
    best_idx = np.nanargmax(fbeta_scores[:-1])
    max_fscore = fbeta_scores[best_idx]
    best_threshold = thresholds[best_idx]

    if plot:
        plt.figure(1)
        plt.subplot(1, 2, 1)
        plt.plot(recall, precision, '.b', label='PR curve')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc="lower right", frameon=True)
        plt.subplot(1, 2, 2)
        plt.plot(thresholds, fbeta_scores[:-1], '.r', label='f1-score')
        plt.xlabel('Probability Threshold')
        plt.ylabel('F1 score')
        plt.show()

    plot_data = (recall, precision, thresholds, fbeta_scores[:-1])
    return best_threshold, max_fscore, plot_data
# Source file: experiment_utils.py (scraped page metadata — view/like/comment counters — removed)