metrics.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:data-preppy 作者: gurgeh 项目源码 文件源码
def metric(model, test_csv, fname):
    X, Y_true, headers = get_XY(test_csv)
    Y_pred = model.predict(X)
    try:
        print confusion_matrix(Y_true, [a[0] > 0.5 for a in Y_pred])
    except IndexError:
        print confusion_matrix(Y_true, [a > 0.5 for a in Y_pred])

    fpr, tpr, _ = roc_curve(Y_true, Y_pred)
    roc_auc = roc_auc_score(Y_true, Y_pred)

    plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC - %s' % fname.split('/')[-1])
    plt.legend(loc="lower right")
    plt.show()
    plt.savefig(fname + ' - roc.png')
    return plt
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号