plot_test_results.py source file

python

Project: marseille  Author: vene
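import numpy as np
from sklearn.metrics import f1_score, precision_recall_fscore_support


def arg_p_r_f(Y_true, Y_pred, labels, **kwargs):
    """Compute "macro" and "micro" precision, recall and F1, plus per-class F1.

    Y_true and Y_pred are sequences of label sequences (e.g. one per document).
    "Macro" scores are means of the per-sequence scores; "micro" scores are
    computed once over all sequences pooled together. kwargs are forwarded to
    scikit-learn's precision_recall_fscore_support.
    """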

    macro_p = []
    macro_r = []
    macro_f = []

    micro_true = []
    micro_pred = []

    for y_true, y_pred in zip(Y_true, Y_pred):
        p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                                     **kwargs)
        macro_p.append(p)
        macro_r.append(r)
        macro_f.append(f)

        micro_true.extend(y_true)
        micro_pred.extend(y_pred)

    micro_p, micro_r, micro_f, _ = precision_recall_fscore_support(
        micro_true, micro_pred, **kwargs
    )
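    # Drop any caller-supplied `average` so it does not clash with average=None
    # below; the default avoids a KeyError when `average` was not passed.
    kwargs.pop('average', None)
    per_class_fs = f1_score(micro_true, micro_pred, average=None, **kwargs)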

    res = {
        'p_macro': np.mean(macro_p),
        'r_macro': np.mean(macro_r),
        'f_macro': np.mean(macro_f),
        'p_micro': micro_p,
        'r_micro': micro_r,
        'f_micro': micro_f
    }

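    # f1_score(average=None) returns scores in sorted label order, so this
    # pairing assumes every label in `labels` occurs in the pooled data.
    for label, per_class_f in zip(sorted(labels), per_class_fs):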
        res['f_class_{}'.format(label)] = per_class_f

    return res
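
The difference between the "macro" and "micro" numbers returned above can be checked by hand on a toy input. The following is a minimal pure-Python sketch, assuming binary labels with positive class 1; `binary_prf` and the toy data are hypothetical and not part of marseille or scikit-learn.

```python
def binary_prf(y_true, y_pred, pos_label=1):
    """Precision, recall and F1 for the positive class (hypothetical helper)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == pos_label and p == pos_label)
    fp = sum(1 for t, p in pairs if t != pos_label and p == pos_label)
    fn = sum(1 for t, p in pairs if t == pos_label and p != pos_label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Toy data: two "documents" with gold and predicted binary labels.
Y_true = [[1, 0, 1, 1], [0, 0, 1]]
Y_pred = [[1, 0, 0, 1], [1, 1, 1]]

# "Macro" in the sense used above: score each document, then average.
per_doc = [binary_prf(t, p) for t, p in zip(Y_true, Y_pred)]
macro_p = sum(s[0] for s in per_doc) / len(per_doc)
macro_r = sum(s[1] for s in per_doc) / len(per_doc)
macro_f = sum(s[2] for s in per_doc) / len(per_doc)

# "Micro" in the sense used above: pool all labels, then score once.
pooled_true = [y for doc in Y_true for y in doc]
pooled_pred = [y for doc in Y_pred for y in doc]
micro_p, micro_r, micro_f = binary_prf(pooled_true, pooled_pred)

print("macro:", macro_p, macro_r, macro_f)
print("micro:", micro_p, micro_r, micro_f)
```

Averaging per-document scores weights every document equally, while pooling weights every label equally, so the two sets of numbers generally differ when documents vary in length or difficulty.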