def main():
    """Extract pretrained-CNN features on PASCAL VOC2007 and train SVMs on them.

    Steps:
      1. Parse command-line options with the module-level ``parser`` and store
         them in the module-level ``args`` (declared ``global`` so other
         functions in this file can read it).
      2. Load ``args.arch`` from ``pretrainedmodels`` with ImageNet weights,
         switch it to eval mode, optionally move it to the GPU, and replace its
         ``last_linear`` head with an identity so the forward pass yields the
         head's input features.
      3. For each of the ``train``/``val``/``test`` splits, build the VOC2007
         dataset and a non-shuffled DataLoader. Then call
         ``extract_features_targets`` with a per-split path
         ``<dir_outputs>/data/<arch>/<split>set.pth``. That path is presumably
         a feature cache; confirm against ``extract_features_targets``.
      4. Concatenate train and val into a ``trainval`` split and call
         ``train_multilabel`` on the requested train/test split pair.

    Raises:
        ValueError: if ``(args.train_split, args.test_split)`` is neither
            ``('train', 'val')`` nor ``('trainval', 'test')``. This is checked
            before any model loading or feature extraction.
        KeyError: if ``args.arch`` is not a name in ``pretrainedmodels``.
    """
    global args
    args = parser.parse_args()

    # Validate the split pair before doing any expensive work. Previously an
    # invalid pair was rejected only after loading the model and extracting
    # features for every split.
    split_modes = {
        ('train', 'val'): '\nHyperparameters search: train multilabel classifiers (on-versus-all) on train/val',
        ('trainval', 'test'): '\nEvaluation: train a multilabel classifier on trainval/test',
    }
    if (args.train_split, args.test_split) not in split_modes:
        raise ValueError('Trying to train on {} and eval on {}'.format(args.train_split, args.test_split))

    print('\nCUDA status: {}'.format(args.cuda))

    print('\nLoad pretrained model on Imagenet')
    model = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
    model.eval()
    if args.cuda:
        model.cuda()
    features_size = model.last_linear.in_features
    # Trick to get the inputs (features) of last_linear: replace the classifier
    # head with an identity so the model outputs penultimate-layer features.
    model.last_linear = pretrainedmodels.utils.Identity()

    splits = ('train', 'val', 'test')

    print('\nLoad datasets')
    tf_img = pretrainedmodels.utils.TransformImage(model)
    datasets = {
        split: pretrainedmodels.datasets.Voc2007Classification(args.dir_datasets, split, transform=tf_img)
        for split in splits
    }
    # shuffle=False keeps sample order stable, so features and targets stay
    # aligned with the dataset order.
    loaders = {
        split: torch.utils.data.DataLoader(datasets[split], batch_size=args.batch_size,
                                           shuffle=False, num_workers=2)
        for split in splits
    }

    print('\nLoad features')
    dir_features = os.path.join(args.dir_outputs, 'data', args.arch)
    features = {}
    targets = {}
    for split in splits:
        path_data = os.path.join(dir_features, '{}set.pth'.format(split))
        features[split], targets[split] = extract_features_targets(
            model, features_size, loaders[split], path_data, args.cuda)
    # trainval = train followed by val, concatenated along the sample dimension.
    features['trainval'] = torch.cat([features['train'], features['val']], 0)
    targets['trainval'] = torch.cat([targets['train'], targets['val']], 0)

    print('\nTrain Support Vector Machines')
    print(split_modes[(args.train_split, args.test_split)])
    train_multilabel(features, targets, datasets['train'].classes,
                     args.train_split, args.test_split, C=args.C)
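# --- Metadata from the web page this file was copied from (not part of the program) ---
# voc2007_extract.py source code
# Language: python
# Reads: 22 | Favorites: 0 | Likes: 0 | Comments: 0
# Comment list
# Table of contents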