metric.py source code

python

Project: FunnyPyML  Author: MrPig
```python
import numpy as np


def cluster_f_measure(ytrue, pred):
    # higher is better; the result lies in [0, 1]
    assert len(ytrue) == len(pred), 'inputs length must be equal.'
    # Map arbitrary labels (strings, non-contiguous ints) to indices 0..k-1,
    # for both the ground truth and the predicted clusters.
    _, _ytrue = np.unique(np.asarray(ytrue), return_inverse=True)
    _, _pred = np.unique(np.asarray(pred), return_inverse=True)
    nSize = len(_ytrue)
    nClassTrue = _ytrue.max() + 1
    nClassPred = _pred.max() + 1
    f = np.zeros((nClassTrue, nClassPred), dtype=np.float64)
    for i in range(nClassTrue):
        freq_i = np.sum(_ytrue == i)  # size of true class i
        for j in range(nClassPred):
            freq_j = np.sum(_pred == j)  # size of predicted cluster j
            freq_i_j = float(np.sum(_pred[_ytrue == i] == j))  # overlap
            precision = freq_i_j / freq_j if freq_j != 0 else 0.
            recall = freq_i_j / freq_i if freq_i != 0 else 0.
            if precision == 0 or recall == 0:
                f[i, j] = 0.
            else:
                f[i, j] = 2. * (precision * recall) / (precision + recall)
    # Weight each true class by its size and score it by its best-matching
    # cluster (max over j); summing over every j could exceed 1.
    class_sizes = np.bincount(_ytrue, minlength=nClassTrue)
    return float(np.sum(class_sizes * f.max(axis=1)) / nSize)
```
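The standard clustering F-measure is F = Σ_i (n_i / n) · max_j F(i, j), where n_i is the size of true class i and F(i, j) is the F1 score of cluster j against class i. A minimal sketch of a vectorized variant builds the class-by-cluster contingency matrix once with `np.add.at` and derives every precision, recall and F1 value from it, avoiding the nested Python loops. The name `cluster_f_measure_vectorized` is hypothetical and is not part of FunnyPyML; it assumes 1-D label sequences of equal length.

```python
import numpy as np


def cluster_f_measure_vectorized(ytrue, pred):
    """Hypothetical vectorized clustering F-measure (not part of FunnyPyML)."""
    ytrue = np.asarray(ytrue)
    pred = np.asarray(pred)
    if ytrue.ndim != 1 or ytrue.shape != pred.shape:
        raise ValueError('inputs must be 1-D and of equal length')
    # Encode labels as 0..k-1 so they can index the contingency matrix.
    _, t = np.unique(ytrue, return_inverse=True)
    _, p = np.unique(pred, return_inverse=True)
    # n[i, j] = number of samples in true class i and predicted cluster j.
    n = np.zeros((t.max() + 1, p.max() + 1), dtype=np.float64)
    np.add.at(n, (t, p), 1)
    class_sizes = n.sum(axis=1, keepdims=True)    # n_i, shape (k_true, 1)
    cluster_sizes = n.sum(axis=0, keepdims=True)  # n_j, shape (1, k_pred)
    precision = n / cluster_sizes
    recall = n / class_sizes
    denom = precision + recall
    # F1 per (class, cluster); left at 0 where class and cluster do not overlap.
    f = np.divide(2 * precision * recall, denom,
                  out=np.zeros_like(n), where=denom > 0)
    # Weight each class by its size, keeping only its best-matching cluster.
    return float((class_sizes.ravel() * f.max(axis=1)).sum() / len(t))
```

For example, putting every sample into a single cluster, `cluster_f_measure_vectorized([0, 0, 1, 1], [0, 0, 0, 0])`, gives 2/3: each class has recall 1 and precision 0.5. A perfect clustering scores 1.0 regardless of how the cluster labels are named.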