def build_model(args):
    """Build a truncated, pretrained torchvision ResNet for feature extraction.

    The returned network is the ResNet stem (``conv1``, ``bn1``, ``relu``,
    ``maxpool``) followed by the first ``args.model_stage`` residual stages
    (``layer1`` ... ``layer<model_stage>``). It is moved to CUDA and put in
    eval mode.

    Args:
        args: Namespace-like object providing:
            model: Name of a model constructor in ``torchvision.models``
                whose name contains ``"resnet"`` (e.g. ``"resnet101"``).
            model_stage: Number of residual stages to keep after the stem.

    Returns:
        torch.nn.Sequential: The truncated model, on CUDA, in eval mode.

    Raises:
        ValueError: If ``args.model`` does not name a callable constructor in
            ``torchvision.models``, or does not name a ResNet variant.
    """
    constructor = getattr(torchvision.models, args.model, None)
    # torchvision.models also exposes submodules (e.g. ``resnet``) that pass
    # a plain hasattr() check but cannot be called to build a network.
    if not callable(constructor):
        raise ValueError('Invalid model "%s"' % args.model)
    if 'resnet' not in args.model:
        raise ValueError('Feature extraction only supports ResNets')

    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept here for compatibility with older versions.
    cnn = constructor(pretrained=True)

    # ResNet stem: everything before the first residual stage.
    layers = [cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool]
    # Append residual stages layer1..layer<model_stage>. torchvision ResNets
    # presumably define layer1-layer4 only; a larger value makes getattr raise
    # AttributeError — confirm against the torchvision version in use.
    for stage in range(1, args.model_stage + 1):
        layers.append(getattr(cnn, 'layer%d' % stage))

    model = torch.nn.Sequential(*layers)
    model.cuda()
    model.eval()
    return model
评论列表
文章目录