nn_finetune_densenet_121.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:KagglePlanetPytorch 作者: Mctigger 项目源码 文件源码
def generate_model(num_classes=17, pretrained_model=None):
    """Build a DenseNet-based classifier with independent sigmoid outputs.

    The convolutional trunk (``features``) of a pretrained DenseNet is reused
    unchanged; only a new linear head with ``num_classes`` outputs is created.
    ``forward`` returns one sigmoid probability per class (multi-label style
    output).

    Args:
        num_classes: Number of outputs of the new linear head. Defaults to 17,
            the value this function always used.
        pretrained_model: Optional backbone exposing ``.features`` (a module
            whose ``_modules`` contain ``'denseblock1'`` .. ``'denseblock4'``)
            and ``.classifier`` (an ``nn.Linear`` whose ``in_features`` sizes
            the new head). When ``None`` (default),
            ``torchvision.models.densenet121(pretrained=True)`` is built.

    Returns:
        A ``DenseModel`` (``nn.Module``) instance.
    """
    class DenseModel(nn.Module):
        """DenseNet trunk -> ReLU -> global average pool -> linear -> sigmoid."""

        def __init__(self, pretrained_model):
            super(DenseModel, self).__init__()
            self.classifier = nn.Linear(
                pretrained_model.classifier.in_features, num_classes)
            # Only the freshly created head is initialised: its bias is zeroed
            # and its weight keeps nn.Linear's default init. The backbone is
            # attached afterwards and deliberately NOT re-initialised, so its
            # pretrained weights are preserved.
            self.classifier.bias.data.zero_()

            self.features = pretrained_model.features
            # Aliases to the backbone's dense blocks. They are not used in
            # forward(), but as registered submodules they contribute
            # dense1.* .. dense4.* keys to state_dict(); removing them would
            # change the checkpoint format.
            self.dense1 = pretrained_model.features._modules['denseblock1']
            self.dense2 = pretrained_model.features._modules['denseblock2']
            self.dense3 = pretrained_model.features._modules['denseblock3']
            self.dense4 = pretrained_model.features._modules['denseblock4']

        def forward(self, x):
            # NOTE(review): x presumably is an (N, C, H, W) image batch in the
            # format the backbone expects — confirm against the data loader.
            features = self.features(x)
            out = F.relu(features, inplace=True)
            # Global average pool over whatever spatial size the trunk emits.
            # For an 8x8 feature map this is identical to
            # avg_pool2d(kernel_size=8); for other sizes it avoids a
            # flattened-size mismatch with the classifier.
            out = F.adaptive_avg_pool2d(out, 1).view(out.size(0), -1)
            return self.classifier(out).sigmoid()

    if pretrained_model is None:
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour
        # of `weights=`; kept for the torchvision version this file targets.
        pretrained_model = torchvision.models.densenet121(pretrained=True)
    return DenseModel(pretrained_model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号