insights.py source code

python

Project: menrva | Author: amirziai

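from time import time

from sklearn.metrics import accuracy_score, auc, confusion_matrix, roc_curve
from sklearn.model_selection import cross_val_score


def clf_scores(clf, x_train, y_train, x_test, y_test, cv=5):
    """Cross-validate clf on the training set, then fit it once and score it on the test set.

    Assumes a binary classification problem (roc_curve requires it).
    """
    info = dict()

    # TODO: extend this to a confusion matrix per fold for more flexibility downstream (tuning)
    # TODO: calculate a set of ROC curves per fold instead of computing one on the test set,
    #       which currently introduces bias
    # cv: number of folds or a CV splitter, passed straight through to cross_val_score
    scores = cross_val_score(clf, x_train, y_train, cv=cv, n_jobs=-1)

    # Time a single fit on the full training set
    start = time()
    clf.fit(x_train, y_train)
    info['runtime'] = time() - start

    y_test_predicted = clf.predict(x_test)
    info['accuracy'] = min(scores)  # worst-case (minimum) fold accuracy
    info['accuracy_test'] = accuracy_score(y_test, y_test_predicted)
    info['accuracy_folds'] = scores
    info['confusion_matrix'] = confusion_matrix(y_test, y_test_predicted)

    # Reuse the model fitted above for the ROC curve; no second fit is needed.
    # clf_predict_proba is a helper defined elsewhere in the module (not shown here);
    # it is assumed to return one positive-class score per test sample.
    fpr, tpr, _ = roc_curve(y_test, clf_predict_proba(clf, x_test))
    info['fpr'] = fpr
    info['tpr'] = tpr
    info['auc'] = auc(fpr, tpr)

    return info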
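The two TODOs in `clf_scores` ask for a confusion matrix and an ROC curve per cross-validation fold instead of a single one computed on the held-out test set. A minimal sketch of that idea follows, assuming binary 0/1 labels and array-like inputs. `per_fold_metrics` and `positive_class_scores` are hypothetical names; `positive_class_scores` is only an illustrative stand-in for the module's `clf_predict_proba`, whose real definition is not shown. The sketch uses scikit-learn's `StratifiedKFold` and `clone` so that each fold trains a fresh, unfitted copy of the estimator.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.model_selection import StratifiedKFold


def positive_class_scores(clf, x):
    """Hypothetical stand-in for clf_predict_proba: one positive-class score per sample."""
    if hasattr(clf, "predict_proba"):
        # Column 1 is the score for clf.classes_[1], i.e. label 1 when labels are {0, 1}
        return clf.predict_proba(x)[:, 1]
    # Fall back to decision_function for models without probability estimates
    return clf.decision_function(x)


def per_fold_metrics(clf, x_train, y_train, n_splits=5, random_state=0):
    """Confusion matrix and ROC curve per CV fold, using training data only (no test set)."""
    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)
    folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    results = []
    for train_idx, val_idx in folds.split(x_train, y_train):
        model = clone(clf)  # fresh, unfitted copy so folds do not share state
        model.fit(x_train[train_idx], y_train[train_idx])
        x_val, y_val = x_train[val_idx], y_train[val_idx]
        fpr, tpr, _ = roc_curve(y_val, positive_class_scores(model, x_val))
        results.append({
            "confusion_matrix": confusion_matrix(y_val, model.predict(x_val)),
            "fpr": fpr,
            "tpr": tpr,
            "auc": auc(fpr, tpr),
        })
    return results


if __name__ == "__main__":
    # Synthetic binary data, for illustration only
    x, y = make_classification(n_samples=200, random_state=0)
    for i, fold in enumerate(per_fold_metrics(LogisticRegression(max_iter=1000), x, y)):
        print(i, round(fold["auc"], 3))
        print(fold["confusion_matrix"])
```

Because every training sample lands in exactly one validation fold, the per-fold confusion matrices add up to a matrix over the whole training set, and the spread of per-fold AUCs can be inspected during tuning without touching `x_test`.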