def load_defined_model(path, num_classes, name, map_location=None):
    """Build a ``models`` architecture by name and load checkpoint weights into it.

    Args:
        path: Path to a checkpoint readable by ``torch.load``. The loaded
            object must contain a ``'state_dict'`` entry.
        num_classes: Forwarded as ``num_classes=`` to the model constructor.
        name: Key used to look up the constructor in ``models.__dict__``.
        map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a machine without CUDA). The default
            ``None`` matches a plain ``torch.load(path)`` call.

    Returns:
        The constructed model with the checkpoint's weights loaded.

    Raises:
        KeyError: if ``name`` is not in ``models.__dict__`` or the checkpoint
            has no ``'state_dict'`` entry.
        AssertionError: if ``diff_states`` reports any difference between the
            model's state dict and the (renamed) checkpoint state dict. The
            mismatching layer names are also printed before raising.
    """
    model = models.__dict__[name](num_classes=num_classes)
    pretrained_state = torch.load(path, map_location=map_location)

    # Drop every "module." segment from the parameter names, presumably the
    # ones added by torch.nn.DataParallel. All occurrences are removed on
    # purpose, not just a leading prefix: when only a sub-module was wrapped,
    # keys look like "features.module.0.weight".
    new_pretrained_state = OrderedDict(
        (k.replace("module.", ""), v)
        for k, v in pretrained_state['state_dict'].items()
    )

    # Diff before loading so mismatches are reported by layer name.
    # NOTE(review): assumes diff_states (defined elsewhere) yields sequences
    # whose first element is the layer name, as the d[0] below implies.
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        mismatched = [d[0] for d in diff]
        print("Mismatch in these layers :", name, ":", mismatched)
        # Raise explicitly instead of using `assert`, so the check still runs
        # under `python -O`. AssertionError is kept for callers that catch it.
        raise AssertionError(
            "Mismatch in these layers : {} : {}".format(name, mismatched))

    # Merge the checkpoint weights into the freshly built model.
    model.load_state_dict(new_pretrained_state)
    return model
#Load the model
saliency.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录