def load_defined_model(name, num_classes):
    """Build model ``name`` with ``num_classes`` outputs, seeded from pretrained weights.

    The pretrained checkpoint is fetched with
    ``model_zoo.load_url(model_urls[name])``. Every entry reported by
    ``diff_states(model.state_dict(), pretrained_state)`` is written into the
    checkpoint before it is loaded into the model. The checkpoint then
    matches the model's state layout. Presumably the replaced entries are
    layers whose shape depends on ``num_classes`` (e.g. the classifier), and
    the values come from the freshly initialized model. NOTE(review): confirm
    against ``diff_states``.

    Args:
        name: Architecture name; used as a key into ``models.__dict__`` and
            ``model_urls`` (``'densenet169'`` is special-cased).
        num_classes: Number of output classes for the constructed model.

    Returns:
        A ``(model, diff)`` tuple, where ``diff`` is the list of
        ``(state_key, value)`` pairs that were substituted into the
        pretrained state.

    Raises:
        RuntimeError: If the merged state still differs from the model's
            state after substitution.
    """
    if name == 'densenet169':
        # Densenets don't (yet) pass on num_classes, so hack it in for 169 by
        # building the DenseNet directly. Doing this *instead of* the generic
        # factory avoids constructing (and randomly initializing) the model
        # twice.
        model = torchvision.models.DenseNet(
            num_init_features=64,
            growth_rate=32,
            block_config=(6, 12, 32, 32),
            num_classes=num_classes,
        )
    else:
        model = models.__dict__[name](num_classes=num_classes)

    pretrained_state = model_zoo.load_url(model_urls[name])

    # Diff: state entries where the initialized model and checkpoint disagree.
    diff = list(diff_states(model.state_dict(), pretrained_state))
    print("Replacing the following state from initialized", name, ":",
          [key for key, _ in diff])
    # Use `key`, not `name`, so the architecture-name parameter is not shadowed.
    for key, value in diff:
        pretrained_state[key] = value

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    remaining = list(diff_states(model.state_dict(), pretrained_state))
    if remaining:
        raise RuntimeError(
            "pretrained state still differs from model state after merge: %r"
            % [entry[0] for entry in remaining]
        )

    # Merge
    model.load_state_dict(pretrained_state)
    return model, diff
train.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录