classifier_tf.py source code

python

Project: human-rl  Author: gsastry
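import numpy as np
from sklearn.metrics import precision_score, recall_score, roc_auc_score


def classification_metrics(y, y_pred, threshold):
    """Summarize binary-classifier predictions at a given decision threshold."""
    metrics = {}
    # threshold_from_predictions is a helper that is not shown in this snippet.
    metrics['threshold'] = threshold_from_predictions(y, y_pred, 0)
    metrics['np.std(y_pred)'] = np.std(y_pred)
    metrics['positive_frac_batch'] = float(np.count_nonzero(y == True)) / len(y)
    # False positive rate: share of negatives scored at or above the threshold.
    denom = np.count_nonzero(y == False)
    num = np.count_nonzero(np.logical_and(y == False, y_pred >= threshold))
    if denom > 0:
        metrics['fpr'] = float(num) / float(denom)
    # AUC is only computed when both classes are present in y.
    if any(y) and not all(y):
        metrics['auc'] = roc_auc_score(y, y_pred)
        y_pred_bool = y_pred >= threshold
        # Skip precision/recall when all predictions fall on one side of the threshold.
        if any(y_pred_bool) and not all(y_pred_bool):
            metrics['precision'] = precision_score(np.array(y, dtype=np.float32), y_pred_bool)
            metrics['recall'] = recall_score(y, y_pred_bool)
    return metrics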
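
The false positive rate step in the function above can be tried on its own. The following is a minimal sketch, not part of the original file: the helper name false_positive_rate and the toy labels and scores are made up for illustration, and it returns None where the original simply leaves the 'fpr' key out.

```python
import numpy as np


def false_positive_rate(y, y_pred, threshold):
    """Fraction of negative examples scored at or above `threshold`.

    Hypothetical helper for illustration; it mirrors the 'fpr' branch of
    classification_metrics but is not taken from the original project.
    """
    y = np.asarray(y, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=float)
    negatives = np.count_nonzero(~y)
    if negatives == 0:
        # Undefined without negatives; the original omits the key instead.
        return None
    false_positives = np.count_nonzero(~y & (y_pred >= threshold))
    return false_positives / negatives


# Toy data (assumed): three negatives, one of which scores above 0.5.
y = [True, False, False, True, False]
y_pred = [0.9, 0.6, 0.2, 0.4, 0.1]
fpr = false_positive_rate(y, y_pred, threshold=0.5)
print(fpr)
```

Converting the inputs with np.asarray lets the sketch accept plain Python lists, whereas the comparisons `y == True` and `y_pred >= threshold` in the original only work element-wise when NumPy arrays are passed in.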