occlusion.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DeepLearning_PlantDiseases 作者: MarkoArsenovic 项目源码 文件源码
def load_defined_model(path, num_classes, name, map_location=None):
    """Build a ``models.<name>`` network and load checkpoint weights into it.

    The checkpoint at *path* is read with ``torch.load``. Its ``'state_dict'``
    entry is used when present; otherwise the loaded object itself is treated
    as the state dict. A leading ``"module."`` prefix on each key is stripped
    before loading (presumably added because the weights were saved from a
    ``DataParallel``-wrapped model — confirm against the training script).

    Args:
        path: Filesystem path of the checkpoint file.
        num_classes: Number of output classes, passed to the model constructor.
        name: Key of the model constructor in ``models.__dict__``.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` to load a
            GPU-saved checkpoint on a machine without CUDA. Defaults to
            ``None``, i.e. ``torch.load``'s default behaviour (as before).

    Returns:
        The constructed model with the checkpoint weights loaded.

    Raises:
        KeyError: If *name* is not present in ``models.__dict__``.
        AssertionError: If ``diff_states`` reports any mismatch between the
            model's state dict and the checkpoint's. This is raised explicitly
            (not via ``assert``) so the check still runs under ``python -O``.
    """
    model = models.__dict__[name](num_classes=num_classes)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # Strip only a *leading* "module." so keys that merely contain that
    # substring deeper in their path are not mangled.
    prefix = "module."
    new_pretrained_state = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

    # Diff: refuse to load if the architecture and checkpoint disagree.
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        mismatched = [d[0] for d in diff]
        print("Mismatch in these layers :", name, ":", mismatched)
        raise AssertionError(
            f"State dict mismatch for model {name!r}: {mismatched}"
        )

    # Merge
    model.load_state_dict(new_pretrained_state)
    return model


#Load the model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号