def test_vis(model, test_loader, model_path, threshlod,
             margin=1.0, is_cuda=True, output_dir='output', is_visualization=True):
    """Evaluate a triplet-embedding model by thresholding pairwise distances.

    For every (anchor, positive, negative) triplet yielded by ``test_loader``,
    the anchor-positive pair counts as correct when its embedding distance is
    ``<= threshlod``, and the anchor-negative pair counts as correct when its
    distance is ``> threshlod``. Accuracy is the fraction of correct pairs
    (two pairs per triplet).

    Uses ``torch.no_grad`` / ``Tensor.item`` (PyTorch >= 0.4). The older
    ``Variable(volatile=True)`` / ``.data[0]`` idioms do not work there:
    ``volatile`` is ignored and indexing a 0-dim tensor raises ``IndexError``.

    Args:
        model: embedding network, called as ``model(batch)``. It is optionally
            loaded from ``model_path``, moved to the GPU when ``is_cuda`` is
            true, and switched to eval mode.
        test_loader: iterable of ``(data_a, data_p, data_n, target, img_paths)``
            batches. ``target`` is ignored. ``img_paths[0][k]``,
            ``img_paths[1][k]`` and ``img_paths[2][k]`` are passed to
            ``combine_and_save`` alongside sample ``k`` of ``data_a``,
            ``data_p`` and ``data_n`` respectively.
        model_path: checkpoint passed to ``model.load_full_weights`` unless None.
        threshlod: distance threshold separating matching from non-matching
            pairs (spelling kept for caller compatibility).
        margin: unused; kept only for signature compatibility (the result of
            the triplet loss previously computed with it was never used).
        is_cuda: move the model and input batches to the GPU.
        output_dir: directory argument forwarded to ``combine_and_save``.
        is_visualization: when true, call ``combine_and_save`` for every
            misclassified pair (tag ``'1-1'`` for anchor-positive pairs,
            ``'1-0'`` for anchor-negative pairs).

    Returns:
        float: correct pairs / total pairs, or 0.0 if the loader yields nothing.
    """
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {:s}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()

    num_p = 0.0    # correctly classified pairs so far
    total_num = 0  # pairs evaluated so far (2 per triplet)

    # no_grad replaces Variable(volatile=True), which PyTorch >= 0.4 ignores;
    # without it an autograd graph would be built for every evaluation batch.
    with torch.no_grad():
        for data_a, data_p, data_n, _target, img_paths in test_loader:
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()

            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)
            dist1 = F.pairwise_distance(out_a, out_p)  # anchor-positive
            dist2 = F.pairwise_distance(out_a, out_n)  # anchor-negative

            # 1.0 where the pair is classified correctly, 0.0 otherwise.
            pos_flag = (dist1 <= threshlod).float()
            neg_flag = (dist2 > threshlod).float()
            batch_size = data_a.size(0)

            if is_visualization:
                for k in range(batch_size):
                    if pos_flag[k].item() == 0:
                        combine_and_save(img_paths[0][k], img_paths[1][k],
                                         dist1[k], output_dir, '1-1')
                    if neg_flag[k].item() == 0:
                        combine_and_save(img_paths[0][k], img_paths[2][k],
                                         dist2[k], output_dir, '1-0')

            pos_correct = pos_flag.sum().item()
            neg_correct = neg_flag.sum().item()
            num = pos_correct + neg_correct
            print('{:f}, {:f}, {:f}'.format(num, pos_correct, neg_correct))
            num_p += num
            total_num += batch_size * 2
            print('num_p = {:f}, total = {:f}'.format(num_p, total_num))
            print('dist1 = {:f}, dist2 = {:f}'.format(dist1[0].item(),
                                                      dist2[0].item()))

    if total_num == 0:
        # Empty loader: nothing was evaluated; avoid ZeroDivisionError.
        return 0.0
    return num_p / total_num
评论列表
文章目录