def validate(args):
# Setup Dataloader
data_loader = get_loader(args.dataset)
data_path = get_data_path(args.dataset)
loader = data_loader(data_path, split=args.split, is_transform=True, img_size=(args.img_rows, args.img_cols))
n_classes = loader.n_classes
valloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4)
running_metrics = runningScore(n_classes)
# Setup Model
model = get_model(args.model_path[:args.model_path.find('_')], n_classes)
state = convert_state_dict(torch.load(args.model_path)['model_state'])
model.load_state_dict(state)
model.eval()
for i, (images, labels) in tqdm(enumerate(valloader)):
model.cuda()
images = Variable(images.cuda(), volatile=True)
labels = Variable(labels.cuda(), volatile=True)
outputs = model(images)
pred = outputs.data.max(1)[1].cpu().numpy()
gt = labels.data.cpu().numpy()
running_metrics.update(gt, pred)
score, class_iou = running_metrics.get_scores()
for k, v in score.items():
print(k, v)
for i in range(n_classes):
print(i, class_iou[i])
评论列表
文章目录