def load_defined_model(name, num_classes):
    """Build torchvision model `name` with `num_classes` outputs and load
    the model-zoo pretrained weights for every layer whose shape matches.

    Layers that differ from the checkpoint (typically the final classifier
    when num_classes != 1000) keep their fresh initialization.

    Returns:
        (model, diff) where `diff` is the list of (layer_name, tensor)
        pairs that were kept from the initialized model instead of the
        pretrained checkpoint.
    """
    model = models.__dict__[name](num_classes=num_classes)
    # DenseNets don't (yet) pass on num_classes; hack it in for densenet169.
    if name == 'densenet169':
        model = torchvision.models.DenseNet(
            num_init_features=64, growth_rate=32,
            block_config=(6, 12, 32, 32), num_classes=num_classes)
    pretrained_state = model_zoo.load_url(model_urls[name])
    # Layers whose state differs between our model and the checkpoint.
    diff = list(diff_states(model.state_dict(), pretrained_state))
    print("Replacing the following state from initialized", name, ":",
          [d[0] for d in diff])
    # Patch the checkpoint with the initialized tensors for mismatched
    # layers. NOTE: the original code used `name` as the loop variable,
    # shadowing the function parameter -- fixed by renaming.
    for layer_name, value in diff:
        pretrained_state[layer_name] = value
    # After patching there must be no remaining mismatches.
    assert len(list(diff_states(model.state_dict(), pretrained_state))) == 0
    # Merge the patched pretrained weights into the model.
    model.load_state_dict(pretrained_state)
    return model, diff
# Example source code for Python's ``models.__dict__()`` usage (translated section header)
def get_cnn(self, arch, pretrained):
    """Load a CNN by architecture name and parallelize it over GPUs.

    For AlexNet/VGG only the convolutional feature extractor is wrapped
    in DataParallel (their large fully-connected classifier stays on one
    device); every other architecture is wrapped whole.
    """
    if pretrained:
        print("=> using pre-trained model '{}'".format(arch))
        model = models.__dict__[arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(arch))
        model = models.__dict__[arch]()

    parallelize_features_only = arch.startswith(('alexnet', 'vgg'))
    if parallelize_features_only:
        model.features = nn.DataParallel(model.features)
        model.cuda()
    else:
        model = nn.DataParallel(model).cuda()
    return model
def load_defined_model(path, num_classes, name):
    """Instantiate torchvision model `name` with `num_classes` outputs and
    load the checkpoint saved at `path`.

    The checkpoint is expected to be a dict with a 'state_dict' entry whose
    keys may carry a leading 'module.' prefix (added by nn.DataParallel);
    that prefix is stripped before loading.

    Raises:
        AssertionError: if any checkpoint layer mismatches the fresh model.
    """
    model = models.__dict__[name](num_classes=num_classes)
    pretrained_state = torch.load(path)
    new_pretrained_state = OrderedDict()
    for k, v in pretrained_state['state_dict'].items():
        # Strip only a LEADING 'module.' (the DataParallel wrapper). The
        # original used str.replace, which would also mangle 'module.'
        # appearing anywhere inside a layer name.
        layer_name = k[len("module."):] if k.startswith("module.") else k
        new_pretrained_state[layer_name] = v
    # Diff: any remaining mismatch means the checkpoint is incompatible.
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if len(diff) != 0:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])
    assert len(diff) == 0
    # Merge
    model.load_state_dict(new_pretrained_state)
    return model
#Load the model
def load_defined_model(path, num_classes, name):
    """Instantiate torchvision model `name` with `num_classes` outputs and
    load the checkpoint saved at `path`.

    The checkpoint is expected to be a dict with a 'state_dict' entry whose
    keys may carry a leading 'module.' prefix (added by nn.DataParallel);
    that prefix is stripped before loading.

    Raises:
        AssertionError: if any checkpoint layer mismatches the fresh model.
    """
    model = models.__dict__[name](num_classes=num_classes)
    pretrained_state = torch.load(path)
    new_pretrained_state = OrderedDict()
    for k, v in pretrained_state['state_dict'].items():
        # Strip only a LEADING 'module.' (the DataParallel wrapper). The
        # original used str.replace, which would also mangle 'module.'
        # appearing anywhere inside a layer name.
        layer_name = k[len("module."):] if k.startswith("module.") else k
        new_pretrained_state[layer_name] = v
    # Diff: any remaining mismatch means the checkpoint is incompatible.
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if len(diff) != 0:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])
    assert len(diff) == 0
    # Merge
    model.load_state_dict(new_pretrained_state)
    return model
#Load the model
def _nn_forward_hook(self, module, input, output, name=''):
if type(output) is list:
self.blobs[name] = [o.data.clone() for o in output]
else:
self.blobs[name] = output.data.clone()
# @staticmethod
# def _load_model_config(model_def):
# if isinstance(model_def, torch.nn.Module):
#
# elif '.' not in os.path.basename(model_def):
# import torchvision.models as models
# if model_def not in models.__dict__:
# raise KeyError('Model {} does not exist in pytorchs model zoo.'.format(model_def))
# print('Loading model {} from pytorch model zoo'.format(model_def))
# return models.__dict__[model_def](pretrained=True)
# else:
# print('Loading model from {}'.format(model_def))
# if model_def.endswith('.t7'):
# return load_legacy_model(model_def)
# else:
# return torch.load(model_def)
#
#
# if type(model_cfg) == str:
# if not os.path.exists(model_cfg):
# try:
# class_ = getattr(applications, model_cfg)
# return class_(weights=model_weights)
# except AttributeError:
# available_mdls = [attr for attr in dir(applications) if callable(getattr(applications, attr))]
# raise ValueError('Could not load pretrained model with key {}. '
# 'Available models: {}'.format(model_cfg, ', '.join(available_mdls)))
#
# with open(model_cfg, 'r') as fileh:
# try:
# return model_from_json(fileh)
# except ValueError:
# pass
#
# try:
# return model_from_yaml(fileh)
# except ValueError:
# pass
#
# raise ValueError('Could not load model from configuration file {}. '
# 'Make sure the path is correct and the file format is yaml or json.'.format(model_cfg))
# elif type(model_cfg) == dict:
# return Model.from_config(model_cfg)
# elif type(model_cfg) == list:
# return Sequential.from_config(model_cfg)
#
# raise ValueError('Could not load model from configuration object of type {}.'.format(type(model_cfg)))