def load_imagenet_model(name, *args, **kwargs):
    """Build a torchvision model and load its weights from ``model_path``.

    Args:
        name: Name of a model constructor in ``torchvision.models``
            (e.g. ``"resnet18"``). The same string is used as the key into
            the module-level ``model_path`` mapping to find the checkpoint URL.
        *args: Positional arguments forwarded unchanged to the torchvision
            constructor.
        **kwargs: Keyword arguments forwarded unchanged to the torchvision
            constructor.

    Returns:
        The constructed model, with the downloaded state dict applied via
        this module's ``load_state_dict(model, state_dict)`` helper.

    Raises:
        KeyError: If ``name`` has no entry in ``model_path``.
        AttributeError: If ``torchvision.models`` has no attribute ``name``.
    """
    # Resolve the checkpoint URL first. An unregistered name then fails fast
    # with a clear message, before a possibly large model is instantiated.
    try:
        url = model_path[name]
    except KeyError:
        raise KeyError(
            f"no pretrained weights URL registered for model {name!r}"
        ) from None

    # Instantiate the architecture by name, e.g. torchvision.models.resnet18(...).
    model = getattr(torchvision.models, name)(*args, **kwargs)

    # model_zoo.load_url downloads the checkpoint, or reuses a previously
    # cached copy, and deserializes it into a state dict.
    state_dict = torch.utils.model_zoo.load_url(url)

    # NOTE(review): load_state_dict is a helper defined elsewhere in this file
    # (not nn.Module.load_state_dict). Presumably it adapts checkpoint keys
    # before loading; confirm its exact behavior there.
    load_state_dict(model, state_dict)
    return model
评论列表
文章目录