evaluation.py source code

Language: Python

Project: SpindleNet, author: yokattame
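import numpy as np
from sklearn.metrics import pairwise_distances


def main(args):
  # _get_test_data and _cmc_core are helpers defined elsewhere in
  # evaluation.py (not shown in this excerpt).
  # PF/PL: probe features/labels; GF/GL: gallery features/labels.
  PF, PL, GF, GL = _get_test_data(args.result_dir)
  # D[j, p] is the distance between gallery sample j and probe p;
  # n_jobs=-2 uses all CPUs but one.
  D = pairwise_distances(GF, PF, metric=args.method, n_jobs=-2)

  gallery_labels_set = np.unique(GL)

  # Warn about probe identities that never appear in the gallery.
  for label in PL:
    if label not in gallery_labels_set:
      print('Probe-id is out of Gallery-id sets.')

  Times = 100  # number of random gallery draws to average over
  k = 20       # compute the CMC curve up to rank k

  res = np.zeros(k)

  # Map each gallery identity to the indices of its gallery samples.
  # A dict (instead of a list indexed by label) also handles labels
  # that are not the contiguous integers 0..n-1.
  gallery_labels_map = {g: [] for g in gallery_labels_set}
  for i, g in enumerate(GL):
    gallery_labels_map[g].append(i)

  for _ in range(Times):
    # Randomly select one gallery sample per gallery identity
    newD = np.zeros((gallery_labels_set.size, PL.size))
    for i, g in enumerate(gallery_labels_set):
      j = np.random.choice(gallery_labels_map[g])
      newD[i, :] = D[j, :]
    # Compute the CMC for this draw and accumulate it
    res += _cmc_core(newD, gallery_labels_set, PL, k)
  res /= Times

  for topk in [1, 5, 10, 20]:
    print("{:8}{:8.1%}".format('top-' + str(topk), res[topk - 1]))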
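The listing calls `_cmc_core`, whose body is not part of this excerpt. Below is a minimal sketch of what such a helper could compute, assuming it returns the cumulative match characteristic (CMC) as fractions in [0, 1] for ranks 1..k, which is how `main` uses the result when it averages over draws and prints percentages. For each probe, the gallery rows are sorted by distance; CMC(r) is the fraction of probes whose identity appears among the r nearest gallery entries. The name `cmc_core_sketch` and its padding behavior for galleries with fewer than k entries are hypothetical and not taken from SpindleNet.

```python
import numpy as np


def cmc_core_sketch(D, G, P, k):
  """Hypothetical stand-in for _cmc_core (the real helper is not shown).

  D: (num_gallery, num_probe) distance matrix, rows aligned with G.
  G: gallery labels, one per row of D.
  P: probe labels, one per column of D.
  Returns a length-k array whose entry r-1 is the fraction of probes
  whose true identity is among the r nearest gallery entries.
  """
  D = np.asarray(D)
  G = np.asarray(G)
  P = np.asarray(P)
  # For each probe (column), order gallery rows by increasing distance.
  order = np.argsort(D, axis=0)
  ranked_labels = G[order]                 # gallery labels in rank order
  matches = ranked_labels == P[np.newaxis, :]
  # Keep only the first k ranks (fewer if the gallery is smaller).
  matches = matches[:k, :]
  # A probe is matched at rank r if any of its first r entries match.
  hit_by_rank = np.cumsum(matches, axis=0) > 0
  cmc = hit_by_rank.mean(axis=1)
  # Assumption: if the gallery has fewer than k rows, the curve stays flat.
  if cmc.size < k:
    cmc = np.concatenate([cmc, np.full(k - cmc.size, cmc[-1])])
  return cmc
```

For example, with gallery labels `[0, 1, 2]`, probe labels `[0, 1]`, and `D = [[0.1, 0.9], [0.5, 0.8], [0.7, 0.2]]`, the first probe's nearest gallery entry is correct and the second probe's correct entry is second-nearest, so the sketch returns 0.5, 1.0, 1.0 for k = 3. Probes whose identity is absent from the gallery are simply never matched, which is why `main` warns about them.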