nn_finetune_resnet_101.py source code

python

Project: KagglePlanetPytorch  Author: Mctigger
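import torch
import torch.nn as nn
import torchvision


def generate_model():
    class MyModel(nn.Module):
        def __init__(self, pretrained_model):
            super(MyModel, self).__init__()
            self.pretrained_model = pretrained_model
            # Expose the four residual stages as attributes of this module.
            self.layer1 = pretrained_model.layer1
            self.layer2 = pretrained_model.layer2
            self.layer3 = pretrained_model.layer3
            self.layer4 = pretrained_model.layer4

            # Fixed 8x8 average pooling: reduces an 8x8 layer4 output to 1x1.
            pretrained_model.avgpool = nn.AvgPool2d(8)
            # New classification head with 17 output units.
            classifier = [
                nn.Linear(pretrained_model.fc.in_features, 17),
            ]

            self.classifier = nn.Sequential(*classifier)
            pretrained_model.fc = self.classifier

        def forward(self, x):
            # torch.sigmoid replaces the deprecated F.sigmoid; each of the
            # 17 outputs is squashed to (0, 1) independently.
            return torch.sigmoid(self.pretrained_model(x))

    # Newer torchvision releases prefer the weights= argument over pretrained=True.
    return MyModel(torchvision.models.resnet101(pretrained=True))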
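The snippet hard-codes `nn.AvgPool2d(8)`, which only collapses the final feature map to 1x1 when that map is 8x8. The sketch below is not part of the original file; it walks the standard ResNet downsampling steps (7x7 stride-2 stem convolution, 3x3 stride-2 max pool, then three stride-2 stages) to show which input size produces that shape. The helper names `conv_out` and `resnet_feature_size` are hypothetical, and the 256x256 input size is an assumption about how the model is fed, not something the file states.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_feature_size(input_size):
    """Side length of the layer4 feature map for a square input (assumed layout)."""
    size = conv_out(input_size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)        # maxpool
    # layer1 keeps the resolution; layer2, layer3 and layer4 each halve it.
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


for side in (224, 256):
    print(side, "->", resnet_feature_size(side))
# prints:
# 224 -> 7
# 256 -> 8
```

Under these assumptions a 256x256 input reaches the pooling layer as an 8x8 map, so the 8x8 window fits exactly, while a 224x224 input would arrive as 7x7, smaller than the pooling window.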