def best_test(model, _loader, model_path=None, is_cuda=True, thre_step=0.05):
    """Find the distance threshold that best separates positive from negative pairs.

    Runs ``model`` in eval mode over every batch of ``_loader`` and computes
    ``F.pairwise_distance`` between the anchor/positive and anchor/negative
    embeddings. Thresholds from the smaller to the larger of the two mean
    distances are then swept in steps of ``thre_step``. A positive pair
    counts as correct when its distance is ``<=`` the threshold, and a
    negative pair when its distance is ``>`` it.

    Args:
        model: Embedding network, called once per input tensor. When
            ``model_path`` is given it must provide ``load_full_weights``.
        _loader: Iterable of ``(anchor, positive, negative, target)`` batches.
            ``target`` is unpacked but not used.
        model_path: Optional weights file loaded before evaluation. It is
            printed with ``{:s}``, so it is expected to be a ``str``.
        is_cuda: When True, move the model and the inputs to the GPU.
        thre_step: Spacing of the threshold sweep. Defaults to 0.05, the
            value that was previously hard-coded.

    Returns:
        ``(best_threshold, accuracy, mean_d_a_p, mean_d_a_n)`` as Python
        numbers. ``accuracy`` is the number of correctly classified pairs
        divided by twice the number of triplets.

    Raises:
        ValueError: if ``_loader`` yields no batches or ``thre_step <= 0``.
    """
    if thre_step <= 0:
        raise ValueError('thre_step must be positive')
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {:s}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()

    d_a_p, d_a_n, num_triplets = _collect_triplet_distances(model, _loader, is_cuda)
    # Each triplet contributes one positive pair and one negative pair.
    total_num = 2 * num_triplets

    # .item() converts the 0-dim results to Python floats (.data[0] fails on modern PyTorch).
    mean_d_a_p = d_a_p.mean().item()
    mean_d_a_n = d_a_n.mean().item()
    start = min(mean_d_a_n, mean_d_a_p)
    end = max(mean_d_a_n, mean_d_a_p)

    best_thre, best_num = _sweep_thresholds(d_a_p, d_a_n, start, end, thre_step)
    return best_thre, best_num / total_num, mean_d_a_p, mean_d_a_n


def _collect_triplet_distances(model, loader, is_cuda):
    """Return (anchor-positive distances, anchor-negative distances, triplet count) over all batches."""
    ap_chunks = []
    an_chunks = []
    num_triplets = 0
    # Inference only: replaces the removed Variable(..., volatile=True) wrapping.
    with torch.no_grad():
        for data_a, data_p, data_n, _target in loader:
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()
            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)
            ap_chunks.append(F.pairwise_distance(out_a, out_p))
            an_chunks.append(F.pairwise_distance(out_a, out_n))
            num_triplets += data_a.size(0)
    if not ap_chunks:
        raise ValueError('_loader yielded no batches; cannot choose a threshold')
    # Concatenate once instead of re-concatenating on every batch.
    return torch.cat(ap_chunks, 0), torch.cat(an_chunks, 0), num_triplets


def _sweep_thresholds(d_a_p, d_a_n, start, end, step):
    """Return (threshold, correct_count) maximising correct pairs over [start, end] in `step` increments."""
    best_thre = 0.0
    best_num = 0
    # end + step so the upper mean itself is (approximately) included in the sweep.
    for val in torch.arange(start, end + step, step).tolist():
        num = ((d_a_p <= val).sum() + (d_a_n > val).sum()).item()
        if num > best_num:  # strict '>': the smallest threshold wins among ties
            best_num = num
            best_thre = val
    return best_thre, best_num
评论列表
文章目录