detection.py source code

python

Project: FeatureSqueezing  Author: uvasrg
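from __future__ import print_function

import numpy as np
from sklearn.metrics import accuracy_score, auc, roc_curve

# get_tpr_fpr is defined elsewhere in the project and is not part of this excerpt.

def train_detector(x_train, y_train, x_val, y_val):
    # x_* are 1-D NumPy arrays of detection scores; y_* are binary labels (1 = positive).
    fpr, tpr, thresholds = roc_curve(y_train, x_train)
    accuracy = [accuracy_score(y_train, x_train > threshold) for threshold in thresholds]
    roc_auc = auc(fpr, tpr)

    # Select the candidate threshold with the highest training accuracy.
    idx_best = np.argmax(accuracy)
    best_threshold = thresholds[idx_best]
    print("Best training accuracy: %.4f, TPR(Recall): %.4f, FPR: %.4f @%.4f" % (accuracy[idx_best], tpr[idx_best], fpr[idx_best], best_threshold))
    print("ROC_AUC: %.4f" % roc_auc)

    # Evaluate every candidate threshold on the validation set, then report the selected one.
    accuracy_val = [accuracy_score(y_val, x_val > threshold) for threshold in thresholds]
    tpr_val, fpr_val = zip(*[get_tpr_fpr(y_val, x_val, threshold) for threshold in thresholds])
    # roc_auc_val = auc(fpr_val, tpr_val)
    print("Validation accuracy: %.4f, TPR(Recall): %.4f, FPR: %.4f @%.4f" % (accuracy_val[idx_best], tpr_val[idx_best], fpr_val[idx_best], best_threshold))

    # The original returned `threshold`, the loop variable leaked from the list comprehension
    # (the last candidate in Python 2, a NameError in Python 3). The selected threshold was presumably intended.
    return best_threshold, accuracy_val, fpr_val, tpr_val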
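
train_detector calls get_tpr_fpr, which this excerpt does not include. The block below is a minimal sketch of what such a helper might look like, not the project's actual implementation. It assumes binary labels with 1 as the positive class and uses the same strict `score > threshold` rule that train_detector applies when computing accuracy. The function body and the toy data are hypothetical.

```python
import numpy as np


def get_tpr_fpr(y_true, scores, threshold):
    """Hypothetical stand-in for the helper used by train_detector.

    Returns (TPR, FPR) when every score strictly above `threshold`
    is flagged as positive.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(scores) > threshold

    true_pos = np.sum(y_pred & y_true)
    false_pos = np.sum(y_pred & ~y_true)
    n_pos = np.sum(y_true)
    n_neg = np.sum(~y_true)

    # Guard against a split that contains only one class.
    tpr = true_pos / float(n_pos) if n_pos else 0.0
    fpr = false_pos / float(n_neg) if n_neg else 0.0
    return tpr, fpr


# Toy data: two negatives and two positives.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
tpr, fpr = get_tpr_fpr(labels, scores, 0.3)
print("TPR: %.2f, FPR: %.2f" % (tpr, fpr))
```

Returning the pair in (TPR, FPR) order matches how train_detector unpacks the results with `zip(*...)` into `tpr_val, fpr_val`.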