def GetPretrainedModel(params, num_classes):
    """Build a pretrained ResNet and replace its head with ``SigmoidLinear``.

    Args:
        params: Mapping with a ``'model'`` entry naming the backbone; one of
            ``'resnet18'``, ``'resnet34'``, ``'resnet50'``, ``'resnet101'``
            or ``'resnet152'``.
        num_classes: Output size passed to ``SigmoidLinear`` for the new head.

    Returns:
        The constructed model, with its ``fc`` attribute replaced by
        ``SigmoidLinear(model.fc.in_features, num_classes)``.

    Raises:
        KeyError: If ``params`` is a dict without a ``'model'`` entry.
        ValueError: If ``params['model']`` is not a supported name.
    """
    # A tuple (not a set) so that an unhashable value still ends up as a
    # ValueError, exactly like the equality comparisons it replaces.
    supported = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')

    name = params['model']
    if name not in supported:
        raise ValueError(
            f"Unknown model type: {name!r} "
            f"(expected one of {', '.join(supported)})"
        )

    # The name is validated against the whitelist above before it is used
    # for attribute lookup, so only the listed constructors can be called.
    # NOTE(review): assuming `models` is torchvision.models, recent releases
    # deprecate `pretrained=` in favour of `weights=`; kept unchanged here to
    # stay compatible with whatever version this file targets -- confirm.
    model = getattr(models, name)(pretrained=True)

    # Swap the final fully connected layer for the project's own head,
    # keeping the backbone's feature width.
    num_features = model.fc.in_features
    model.fc = SigmoidLinear(num_features, num_classes)
    return model
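# (Web-page residue from the original post: "Comment list" / "Table of contents" navigation labels; not part of the code.)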