metrics.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:TemporalConvolutionalNetworks 作者: colincsl 项目源码 文件源码
def macro_accuracy(P, Y, n_classes, bg_class=None, return_all=False, **kwargs):
    """Return the macro-averaged (class-balanced) accuracy, in percent.

    For each class ``c`` the per-class score is the fraction of samples whose
    true label is ``c`` that were predicted as ``c`` (per-class recall). The
    macro score is the unweighted mean of these per-class scores, so rare
    classes count as much as frequent ones.

    Args:
        P: Predicted labels. Either a single label sequence, or a ``list`` of
            sequences; in the list case the metric is computed per element
            and the results are averaged.
        Y: Ground-truth labels, with the same structure as ``P``.
        n_classes: Number of classes; the confusion matrix is built over the
            labels ``0..n_classes-1``.
        bg_class: Optional class index to exclude from the average (e.g. a
            background label).
        return_all: If True, also return the per-class scores.
        **kwargs: Ignored (presumably accepted so callers can pass a shared
            set of metric options — confirm against callers).

    Returns:
        ``macro`` (float), or ``(macro, per_class)`` when ``return_all`` is
        True. ``per_class`` holds one score per class in ``0..n_classes-1``
        (without ``bg_class``). Classes with no ground-truth samples score 0
        and are still included in the mean.

    Raises:
        ValueError: If ``P`` and ``Y`` are lists of different lengths.
    """
    def macro_(P, Y, n_classes=None, bg_class=None, return_all=False):
        # Rows are true labels, columns are predictions (sklearn convention;
        # `sm` is presumably sklearn.metrics).
        conf_matrix = sm.confusion_matrix(Y, P, labels=np.arange(n_classes))
        correct = np.diag(conf_matrix).astype(float)
        # Normalise by the number of true samples per class (row sums).
        # Dividing by column sums (predicted counts) would give precision.
        support = conf_matrix.sum(axis=1)
        # Classes absent from Y get 0 instead of an undefined 0/0.
        diag = np.divide(correct, support, out=np.zeros_like(correct),
                         where=support > 0) * 100.

        # Remove background score
        if bg_class is not None:
            diag = diag[np.arange(n_classes) != bg_class]

        macro = diag.mean()
        return (macro, diag) if return_all else macro

    if isinstance(P, list):
        if len(P) != len(Y):
            raise ValueError(
                "P and Y must contain the same number of sequences "
                "(got %d and %d)" % (len(P), len(Y)))
        out = [macro_(p, y, n_classes=n_classes, bg_class=bg_class,
                      return_all=return_all)
               for p, y in zip(P, Y)]
        if return_all:
            # Average the macro scores and, element-wise, the per-class scores.
            return (np.mean([o[0] for o in out]),
                    np.mean([o[1] for o in out], 0))
        return np.mean(out)

    return macro_(P, Y, n_classes=n_classes, bg_class=bg_class,
                  return_all=return_all)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号