def load_state_dict(model, model_urls, model_root):
    """Download a checkpoint and copy its tensors into ``model`` in place.

    Keys from ``model.state_dict()`` are normalised by deleting every
    ``group<N>.`` occurrence, so a checkpoint saved without those grouping
    containers can be loaded into a model that has them. Loading is strict:
    every checkpoint key must exist among the normalised model keys, and every
    normalised model key must appear in the checkpoint.

    Args:
        model: ``torch.nn.Module`` whose parameters/buffers are overwritten.
        model_urls: URL of the checkpoint (a single URL despite the plural
            name), passed to ``model_zoo.load_url``.
        model_root: directory ``model_zoo.load_url`` uses to cache the file.

    Raises:
        KeyError: if the checkpoint has a key the model lacks, the model has a
            key the checkpoint lacks, or removing ``group<N>.`` maps two model
            keys onto the same name.
        RuntimeError: if a checkpoint tensor's shape differs from the model
            tensor it would be copied into.
    """
    from torch.utils import model_zoo
    from torch import nn
    import re
    from collections import OrderedDict

    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    group_re = re.compile(r'group\d+\.')

    own_state = OrderedDict()  # model keys with all 'groupN.' parts removed
    for key, value in model.state_dict().items():
        stripped = group_re.sub('', key)
        if stripped in own_state:
            # Without this check the earlier tensor would be silently dropped:
            # never loaded and never reported as missing.
            raise KeyError(
                'keys collide after removing "group" prefixes: "{}"'
                .format(stripped))
        own_state[stripped] = value

    state_dict = model_zoo.load_url(model_urls, model_root)
    for name, param in state_dict.items():
        if name not in own_state:
            raise KeyError(
                'unexpected key "{}" in state_dict; expected one of: {}'
                .format(name, list(own_state.keys())))
        if isinstance(param, nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        target = own_state[name]
        # copy_ would silently broadcast compatible-but-different shapes.
        if target.shape != param.shape:
            raise RuntimeError(
                'size mismatch for "{}": model has {}, state_dict has {}'
                .format(name, tuple(target.shape), tuple(param.shape)))
        # Tensors returned by state_dict() share storage with the module's
        # parameters/buffers, so this in-place copy updates the model itself.
        target.copy_(param)

    missing = set(own_state.keys()) - set(state_dict.keys())
    if missing:
        raise KeyError('missing keys in state_dict: "{}"'.format(missing))
评论列表
文章目录