def retrieve2file(self, out_file, numn=0, nump=0):
    """Save hard positive/negative ids for every query and return the mAP.

    Highly customised for hard-example selection ("hardsel"). It requires
    every query to have the same number of relevant items among its retrieved
    ids, so that the positives and the negatives can each be packed into a
    rectangular (n_queries, -1) array.

    Args:
        out_file: Destination path. Results are written with ``np.savez``
            (arrays ``pos`` and ``neg``) only when the path ends in
            ``'.npz'``; any other path skips saving.
        numn: Keep at most this many negatives per query (0 keeps all).
        nump: Keep at most this many positives per query (0 keeps all).

    Returns:
        Mean average precision over all queries, computed from the retrieved
        id lists returned by ``self.retrieve()``.

    Raises:
        ValueError: If queries have differing numbers of relevant items.
            Such results cannot be packed per query, and the plain reshape
            would otherwise silently mix ids from different queries.
    """
    ids = self.retrieve()
    # NOTE(review): assumes ids is (n_queries, k) ranked best-first, and that
    # label_q is (n_queries, 1) so it broadcasts along each row -- TODO confirm.
    ret_labels = self.label_src[ids]
    rel = ret_labels == self.label_q
    n_queries = rel.shape[0]

    # Boolean-mask indexing flattens row by row, so reshaping back to
    # (n_queries, -1) is only correct when every row has the same count.
    # Unequal counts whose total still divides evenly would otherwise be
    # reshaped without error, assigning ids to the wrong queries.
    n_rel = rel.sum(axis=1)
    if n_rel.size and np.any(n_rel != n_rel[0]):
        raise ValueError(
            "retrieve2file requires the same number of relevant items per "
            f"query; got counts ranging from {n_rel.min()} to {n_rel.max()}"
        )

    # Positives are reversed, so if ids are ranked best-first the
    # lowest-ranked (hard) positives come first. Negatives keep retrieval
    # order, so the highest-ranked (hard) negatives come first.
    pos = np.fliplr(ids[rel].reshape([n_queries, -1]))
    neg = ids[~rel].reshape([n_queries, -1])
    if 0 < nump < pos.shape[1]:
        pos = pos[:, :nump]
    if 0 < numn < neg.shape[1]:
        neg = neg[:, :numn]

    if out_file.endswith('.npz'):
        np.savez(out_file, pos=pos, neg=neg)

    # Precision at every rank, then AP = mean precision over the relevant
    # ranks. eps avoids 0/0 for queries with no relevant item retrieved.
    ranks = np.arange(1, rel.shape[1] + 1, dtype=np.float32)[None, ...]
    precision = np.cumsum(rel, axis=1) / ranks
    ap = np.sum(precision * rel, axis=1) / (n_rel + np.finfo(np.float32).eps)
    return ap.mean()
评论列表
文章目录