def retrieval_perlabel(X_train, Y_train, X_test, Y_test, fractions=(0.01, 0.5, 1.0)):
    """Compute label-balanced retrieval precision at several cut-off fractions.

    Every test sample is used as a query: all training samples are ranked by
    descending dot-product score against it, and precision is the share of the
    top ``ntop`` retrieved items whose training label equals the query's label,
    with ``ntop = int(fr * len(X_train))`` clamped to ``[1, len(X_train)]``.
    Precisions are first averaged within each test label and then averaged
    across the distinct labels in ``Y_test``, so every label carries equal
    weight regardless of how many queries it has.

    Args:
        X_train: training feature matrix, one row per sample (passed through
            ``unitmatrix`` before scoring).
        Y_train: training labels, one per row of ``X_train``.
        X_test: query feature matrix, one row per query (passed through
            ``unitmatrix`` before scoring).
        Y_test: array-like of query labels exposing ``tolist()``, one per row
            of ``X_test``.
        fractions: cut-offs expressed as fractions of the training-set size.
            Duplicate values are evaluated once.

    Returns:
        List of ``(fraction, mean_precision)`` pairs sorted by fraction; empty
        when there are no queries.

    Raises:
        ValueError: if label and feature counts disagree, or if there are
            queries but ``X_train`` has no rows.
    """
    import numpy as np

    # NOTE(review): ``unitmatrix`` is defined elsewhere in this file; presumably
    # it L2-normalises each row so the dot product below is a cosine similarity.
    X_train = unitmatrix(X_train)
    X_test = unitmatrix(X_test)
    score = X_test.dot(X_train.T)  # score[q, t]: similarity of query q to training row t

    train_labels = np.asarray(Y_train)
    test_labels = Y_test.tolist()
    n_train = len(X_train)
    if len(train_labels) != n_train:
        raise ValueError("Y_train must have one label per row of X_train")
    if len(test_labels) != len(score):
        raise ValueError("Y_test must have one label per row of X_test")
    if test_labels and n_train == 0:
        raise ValueError("X_train must contain at least one sample")

    # Map each distinct fraction to its cut-off. Clamping to [1, n_train]
    # avoids dividing by zero for tiny fractions and stops fractions > 1 from
    # dividing by more items than can actually be retrieved.
    cutoffs = {fr: min(max(int(fr * n_train), 1), n_train) for fr in fractions}

    # precision_sums[fr][label]: precision summed over all queries with that label.
    precision_sums = defaultdict(lambda: defaultdict(float))
    for row, label in zip(score, test_labels):
        ranking = row.argsort()[::-1]  # training indices, highest score first
        # hits[k] = number of matching labels among the k + 1 best-ranked items;
        # a single cumulative sum answers every cut-off for this query.
        hits = np.cumsum(train_labels[ranking] == label)
        for fr, ntop in cutoffs.items():
            precision_sums[fr][label] += float(hits[ntop - 1]) / ntop

    label_counter = Counter(test_labels)
    macro = {}
    for fr, per_label in precision_sums.items():
        # Mean precision within each label, then an unweighted mean over labels.
        per_label_means = (total / label_counter[lbl] for lbl, total in per_label.items())
        macro[fr] = sum(per_label_means) / len(label_counter)
    return sorted(macro.items(), key=lambda item: item[0])
评论列表
文章目录