def print_metrics(clf, n_folds=5, out_path='auc_sent.png'):
    """Plot per-fold and mean ROC curves for ``clf`` and save the figure.

    The classifier is re-fitted on each training split of a stratified
    k-fold partition and scored on the matching test split. One ROC curve
    is drawn per fold, plus a chance diagonal and the mean curve, and the
    resulting figure is written to ``out_path``.

    The data is not passed in: ``features`` and ``labels`` are read from
    the enclosing (module) scope and are indexed with the fold index
    arrays, so they are presumably NumPy arrays -- confirm against the
    caller.

    Args:
        clf: Estimator exposing ``fit(X, y)`` (returning the fitted
            estimator) and ``predict_proba(X)``. Column 1 of the
            probability output is used as the positive-class score.
        n_folds: Number of stratified folds (default 5).
        out_path: File the figure is saved to (default ``'auc_sent.png'``).

    Returns:
        None. The function draws on the current pyplot figure and saves it.
    """
    cv = cross_validation.StratifiedKFold(labels, n_folds=n_folds)

    # Common FPR grid so the per-fold curves can be averaged point-wise.
    mean_fpr = np.linspace(0, 1, 100)
    mean_tpr = np.zeros_like(mean_fpr)

    for i, (train, test) in enumerate(cv):
        fitted = clf.fit(features[train], labels[train])
        probas_ = fitted.predict_proba(features[test])
        fpr, tpr, _ = metrics.roc_curve(labels[test], probas_[:, 1])

        # Resample this fold's curve onto the shared grid and accumulate.
        mean_tpr += np.interp(mean_fpr, fpr, tpr)
        # Pin the accumulated curve to the origin.
        mean_tpr[0] = 0.0

        roc_auc = metrics.auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=1,
                 label='ROC fold %d (area = %0.2f)' % (i, roc_auc))

    # Diagonal reference line, labelled 'Luck' in the legend.
    plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck')

    mean_tpr /= len(cv)
    # Force the averaged curve to end at TPR = 1.
    mean_tpr[-1] = 1.0
    mean_auc = metrics.auc(mean_fpr, mean_tpr)
    plt.plot(mean_fpr, mean_tpr, 'k--',
             label='Mean ROC (area = %0.2f)' % mean_auc, lw=2)

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    # NOTE: everything is drawn on the current pyplot figure, which is not
    # cleared here, so repeated calls add curves to the same axes.
    plt.savefig(out_path)
TwitterResults.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录