pipeline.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:RIDDLE 作者: jisungk 项目源码 文件源码
def evaluate(y_test, y_test_proba, nb_classes, path):
    from riddle import roc # here so np can be seeded before run_pipeline() call

    y_pred = [np.argmax(p) for p in y_test_proba]

    print('Confusion matrix:')
    print(confusion_matrix(y_test, y_pred))
    print()

    print('Classification report:')
    print(classification_report(y_test, y_pred, digits=3))

    print('ROC AUC values:')
    roc_auc, fpr, tpr = roc.compute_roc(y_test, y_test_proba, 
        nb_classes=nb_classes)
    roc.save_plots(roc_auc, fpr, tpr, nb_classes=nb_classes, path=path)

    for l, r in roc_auc.items():
        print('  {}: {:.5f}'.format(l, r))
    print()

# ---------------------------- PUBLIC FUNCTIONS ------------------------------ #
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号