def build_cnn(args, dtype):
    """Build a truncated, pretrained torchvision ResNet for feature extraction.

    The returned network is the ResNet stem (``conv1``, ``bn1``, ``relu``,
    ``maxpool``) followed by the first ``args.cnn_model_stage`` residual
    stages (``layer1`` .. ``layerN``). It is cast with
    ``torch.nn.Module.type(dtype)`` and switched to eval mode.

    Args:
        args: Namespace-like object providing ``cnn_model`` (the name of a
            ``torchvision.models`` constructor, e.g. ``"resnet101"``) and
            ``cnn_model_stage`` (number of residual stages to keep, 0-4).
        dtype: Passed straight to ``torch.nn.Module.type``; presumably a
            tensor type such as ``torch.FloatTensor`` or
            ``torch.cuda.FloatTensor`` -- TODO confirm against callers.

    Returns:
        torch.nn.Sequential: the truncated network, in eval mode.

    Raises:
        ValueError: if ``args.cnn_model`` is not an attribute of
            ``torchvision.models``, does not name a ResNet, or
            ``args.cnn_model_stage`` is outside the range 0-4.
    """
    # torchvision ResNets expose exactly four residual stages: layer1..layer4.
    num_stages = 4

    model_name = args.cnn_model
    if not hasattr(torchvision.models, model_name):
        raise ValueError('Invalid model "%s"' % model_name)
    if 'resnet' not in model_name:
        raise ValueError('Feature extraction only supports ResNets')

    stage = args.cnn_model_stage
    # Validate before constructing the model. Otherwise a too-large stage fails
    # with an opaque AttributeError (no ``layer5``) only after the pretrained
    # weights are loaded, and a negative stage silently keeps only the stem.
    if not 0 <= stage <= num_stages:
        raise ValueError(
            'cnn_model_stage must be between 0 and %d, got %r'
            % (num_stages, stage))

    # NOTE: torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
    # ``weights=...``; kept here so older torchvision releases still work.
    whole_cnn = getattr(torchvision.models, model_name)(pretrained=True)
    layers = [
        whole_cnn.conv1,
        whole_cnn.bn1,
        whole_cnn.relu,
        whole_cnn.maxpool,
    ]
    layers.extend(getattr(whole_cnn, 'layer%d' % (i + 1)) for i in range(stage))
    cnn = torch.nn.Sequential(*layers)
    cnn.type(dtype)
    cnn.eval()
    return cnn
评论列表
文章目录