def get_model(name, n_classes):
    """Build the segmentation model registered under ``name``.

    ``_get_model_instance(name)`` is called to obtain a constructor (presumably
    the model class; verify against its definition). That constructor is then
    invoked with architecture-specific keyword arguments.

    Args:
        name: Model identifier. Names with special handling are 'frrnA',
            'frrnB', 'fcn32s', 'fcn16s', 'fcn8s', 'segnet' and 'unet'. Any other
            name is constructed with ``n_classes`` only.
        n_classes: Number of output classes, forwarded to the constructor.

    Returns:
        The constructed model instance. For the FCN variants and SegNet, its
        parameters are initialised from a VGG16 loaded via
        ``models.vgg16(pretrained=True)`` (presumably torchvision ImageNet
        weights; confirm the ``models`` import).
    """
    model_cls = _get_model_instance(name)

    if name in ('frrnA', 'frrnB'):
        # The trailing letter of the name ('A' or 'B') selects the FRRN variant.
        return model_cls(n_classes, model_type=name[-1])

    if name == 'unet':
        return model_cls(n_classes=n_classes,
                         is_batchnorm=True,
                         in_channels=3,
                         is_deconv=True)

    if name == 'segnet':
        model = model_cls(n_classes=n_classes, is_unpooling=True)
    else:
        model = model_cls(n_classes=n_classes)

    # FCN variants and SegNet share a VGG16-based encoder, so both are seeded
    # from the same pretrained VGG16 weights (previously duplicated per branch).
    if name in ('fcn32s', 'fcn16s', 'fcn8s', 'segnet'):
        vgg16 = models.vgg16(pretrained=True)
        model.init_vgg16_params(vgg16)

    return model
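# NOTE(review): stray web-page navigation text from the scraped source
# ("评论列表" = "comment list", "文章目录" = "table of contents"); not code.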