def test(model, test_loader, epoch, margin, threshlod, is_cuda=True, log_interval=1000):
    """Evaluate a triplet-embedding model on ``test_loader`` and return its accuracy.

    For every batch the model embeds the anchor, positive and negative inputs.
    Each sample contributes two decisions. The anchor-positive distance counts
    as correct when it is below ``threshlod``. The anchor-negative distance
    counts as correct when it is above ``threshlod``.

    Args:
        model: Embedding network; switched to eval mode here and called once
            per input of each triplet.
        test_loader: Iterable supporting ``len()`` that yields
            ``(data_a, data_p, data_n, target)`` batches; ``target`` is unused.
        epoch: Epoch number, only used in the progress log line.
        margin: Margin passed to ``F.triplet_margin_loss``.
        threshlod: Distance threshold for the pair decisions (misspelled name
            kept for backward compatibility with keyword callers).
        is_cuda: If true, move the three input batches to the GPU via ``.cuda()``.
        log_interval: Print running accuracy/loss every this many batches;
            a falsy value (0 or None) disables the progress output.

    Returns:
        float: Correct pair decisions divided by total pair decisions over the
        whole loader, or ``0.0`` if the loader produced no samples.
    """
    # Local import: the visible file only shows ``F`` / ``AverageMeter`` in use.
    import torch

    model.eval()
    test_loss = AverageMeter()
    num_p = 0.0
    total_num = 0
    batch_num = len(test_loader)
    # torch.no_grad() replaces ``Variable(..., volatile=True)``, which no longer
    # disables autograd on PyTorch >= 0.4 (it only emits a warning).
    with torch.no_grad():
        for batch_idx, (data_a, data_p, data_n, _target) in enumerate(test_loader):
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()
            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)
            loss = F.triplet_margin_loss(out_a, out_p, out_n, margin=margin)
            dist_ap = F.pairwise_distance(out_a, out_p)
            dist_an = F.pairwise_distance(out_a, out_n)
            # ``.item()`` instead of ``.data[0]``: indexing a 0-dim tensor
            # raises IndexError on PyTorch >= 0.5.
            correct = (dist_ap < threshlod).float().sum() + (dist_an > threshlod).float().sum()
            num_p += correct.item()
            # Two decisions (positive pair + negative pair) per sample.
            total_num += data_a.size(0) * 2
            test_loss.update(loss.item())
            if log_interval and (batch_idx + 1) % log_interval == 0:
                accuracy_tmp = num_p / total_num  # cumulative since the first batch
                print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'
                      .format(epoch, batch_idx + 1, batch_num, accuracy_tmp, test_loss.avg))
                # Presumably makes the logged loss a per-window average;
                # depends on AverageMeter.reset semantics — verify.
                test_loss.reset()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return num_p / total_num if total_num else 0.0
评论列表
文章目录