retrieval.py source code (Python)

Project: KATE, author: hugochan
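```python
from collections import Counter, defaultdict


def retrieval_perlabel(X_train, Y_train, X_test, Y_test, fractions=(0.01, 0.5, 1.0)):
    # `unitmatrix` is a helper defined elsewhere in the KATE project; it normalizes
    # each row, so (assuming L2 normalization) `score` holds cosine similarities.
    X_train = unitmatrix(X_train)  # normalize
    X_test = unitmatrix(X_test)
    score = X_test.dot(X_train.T)  # shape (n_test, n_train)
    precisions = defaultdict(lambda: defaultdict(float))  # fraction -> label -> summed precision
    label_counter = Counter(Y_test.tolist())  # number of test queries per label

    for idx in range(len(X_test)):
        # Training documents ranked by descending similarity to test query `idx`
        retrieval_idx = score[idx].argsort()[::-1]
        for fr in fractions:
            # Cutoff size; clamped to 1 so small fractions cannot divide by zero
            ntop = max(1, int(fr * len(X_train)))
            hits = sum(1 for i in retrieval_idx[:ntop] if Y_train[i] == Y_test[idx])
            precisions[fr][Y_test[idx]] += float(hits) / ntop

    new_pr = {}
    for fr, val in precisions.items():
        # Mean precision within each label, then the mean over labels (macro average)
        avg_pr = 0.
        for label, pr in val.items():
            avg_pr += pr / label_counter[label]
        new_pr[fr] = avg_pr / len(label_counter)

    return sorted(new_pr.items(), key=lambda d: d[0])
```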
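The function scores each test document against every training document, keeps the top `fraction × len(X_train)` matches, measures what share of them carry the query's label, and then macro-averages that precision over labels. For larger test sets the per-query Python loop can be replaced with array operations. The following is a minimal sketch under stated assumptions, not the KATE implementation: `retrieval_perlabel_vectorized` and `l2_normalize_rows` are hypothetical names, `l2_normalize_rows` is a stand-in for the project's `unitmatrix` (assumed here to L2-normalize rows), and each cutoff is clamped to at least one retrieved document.

```python
import numpy as np


def l2_normalize_rows(X):
    # Hypothetical stand-in for KATE's `unitmatrix`: scale each row to unit L2 norm
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.maximum(norms, 1e-12)  # avoid dividing by zero for all-zero rows


def retrieval_perlabel_vectorized(X_train, Y_train, X_test, Y_test, fractions=(0.01, 0.5, 1.0)):
    X_train = l2_normalize_rows(np.asarray(X_train, dtype=float))
    X_test = l2_normalize_rows(np.asarray(X_test, dtype=float))
    Y_train = np.asarray(Y_train)
    Y_test = np.asarray(Y_test)

    score = X_test @ X_train.T                     # (n_test, n_train) cosine similarities
    ranked = score.argsort(axis=1)[:, ::-1]        # training indices, best match first
    relevant = Y_train[ranked] == Y_test[:, None]  # True where the retrieved doc shares the query's label

    labels = np.unique(Y_test)
    results = []
    for fr in sorted(set(fractions)):
        ntop = max(1, int(fr * len(X_train)))
        per_query = relevant[:, :ntop].mean(axis=1)  # precision of each query at this cutoff
        # Macro average: mean within each label, then mean over labels
        per_label = [per_query[Y_test == lab].mean() for lab in labels]
        results.append((fr, float(np.mean(per_label))))
    return results
```

The relevance matrix is computed once and each cutoff only slices it, which trades the nested loops for O(n_test × n_train) memory. Because it reuses the same `argsort()`-then-reverse ordering, its rankings should match the loop version; results could differ only where similarity scores tie.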