voc2007_extract.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pretrained-models.pytorch 作者: Cadene 项目源码 文件源码
def main():
    """Extract VOC2007 features with a pretrained network and train SVMs on them.

    Parses the command line into the module-level ``args``, loads the
    ImageNet-pretrained architecture named by ``args.arch``, replaces its
    ``last_linear`` layer with an identity so that a forward pass yields the
    inputs of that layer, runs feature extraction over the train/val/test
    splits, and finally hands the features to ``train_multilabel``.

    Raises:
        ValueError: if ``(args.train_split, args.test_split)`` is neither
            ``('train', 'val')`` nor ``('trainval', 'test')``. The check runs
            right after argument parsing, before any model or data is loaded.
    """
    global args
    args = parser.parse_args()

    # Validate the split combination up front so that a bad argument fails
    # immediately rather than after the model load and feature extraction.
    if args.train_split == 'train' and args.test_split == 'val':
        mode_message = '\nHyperparameters search: train multilabel classifiers (on-versus-all) on train/val'
    elif args.train_split == 'trainval' and args.test_split == 'test':
        mode_message = '\nEvaluation: train a multilabel classifier on trainval/test'
    else:
        raise ValueError('Trying to train on {} and eval on {}'.format(args.train_split, args.test_split))

    print('\nCUDA status: {}'.format(args.cuda))

    print('\nLoad pretrained model on Imagenet')
    model = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
    model.eval()
    if args.cuda:
        model.cuda()

    # Read the feature width before last_linear is swapped out below.
    features_size = model.last_linear.in_features
    model.last_linear = pretrainedmodels.utils.Identity()  # Trick to get inputs (features) from last_linear

    print('\nLoad datasets')
    splits = ('train', 'val', 'test')
    tf_img = pretrainedmodels.utils.TransformImage(model)
    datasets = {
        split: pretrainedmodels.datasets.Voc2007Classification(args.dir_datasets, split, transform=tf_img)
        for split in splits
    }
    loaders = {
        split: torch.utils.data.DataLoader(
            datasets[split], batch_size=args.batch_size, shuffle=False, num_workers=2)
        for split in splits
    }

    print('\nLoad features')
    dir_features = os.path.join(args.dir_outputs, 'data/{}'.format(args.arch))

    features = {}
    targets = {}
    for split in splits:
        # NOTE(review): the path is presumably where extract_features_targets
        # caches/reloads the extracted tensors -- confirm against that helper.
        path_data = '{}/{}set.pth'.format(dir_features, split)
        features[split], targets[split] = extract_features_targets(
            model, features_size, loaders[split], path_data, args.cuda)

    # 'trainval' is train followed by val, concatenated along dimension 0.
    features['trainval'] = torch.cat([features['train'], features['val']], 0)
    targets['trainval'] = torch.cat([targets['train'], targets['val']], 0)

    print('\nTrain Support Vector Machines')
    print(mode_message)

    train_multilabel(features, targets, datasets['train'].classes, args.train_split, args.test_split, C=args.C)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号