PSPNet.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:PytorchDL 作者: FredHuangBia 项目源码 文件源码
def __init__(self, opt):
        """Build PSPNet: a torchvision ResNet backbone followed by a pyramid-pooling head.

        Args:
            opt: Options object. This constructor reads:
                ``opt.netSpec`` -- backbone name, one of ``'resnet101'``,
                ``'resnet50'`` or ``'resnet34'``;
                ``opt.pretrain`` -- forwarded as ``pretrained=`` to the
                torchvision constructor;
                ``opt.numClasses`` -- output channels of the final 1x1 conv.

        Raises:
            ValueError: If ``opt.netSpec`` is not a supported backbone name.
        """
        super().__init__()
        self.opt = opt

        # Dispatch table of supported backbones. Previously an unknown name
        # fell through the if/elif chain and crashed with UnboundLocalError.
        backbones = {
            'resnet101': models.resnet101,
            'resnet50': models.resnet50,
            'resnet34': models.resnet34,
        }
        try:
            build_backbone = backbones[opt.netSpec]
        except KeyError:
            raise ValueError(
                'Unsupported netSpec %r; expected one of %s'
                % (opt.netSpec, sorted(backbones))) from None
        resnet = build_backbone(pretrained=opt.pretrain)

        # Reuse the backbone stages.
        # NOTE(review): resnet.bn1 / relu / maxpool are not kept here; confirm
        # that forward() handles the stem as intended.
        self.conv1 = resnet.conv1
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Freeze the backbone's conv and batch-norm weights. At this point
        # self.modules() contains only the backbone stages registered above,
        # so the head layers created below remain trainable. Setting
        # `requires_grad` on an nn.Module object itself has no effect; the flag
        # must be set on each parameter.
        # NOTE(review): BatchNorm running statistics still update in train()
        # mode; call .eval() on the backbone as well if they should be frozen.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                for p in m.parameters():
                    p.requires_grad = False

        # Pyramid branches; the last argument is presumably the pooled output
        # size (1x1, 2x2, 3x3, 6x6) -- PSPDec is defined elsewhere, so confirm.
        # NOTE(review): the 512-channel input matches resnet34's final stage;
        # resnet50/101 final stages are wider, so confirm that forward() feeds
        # a 512-channel map for those backbones.
        self.layer5a = PSPDec(512, 128, (1,1))
        self.layer5b = PSPDec(512, 128, (2,2))
        self.layer5c = PSPDec(512, 128, (3,3))
        self.layer5d = PSPDec(512, 128, (6,6))

        # Classifier. The 512*2 input presumably is the 512-channel feature map
        # concatenated with the four 128-channel pyramid outputs
        # (512 + 4*128 = 1024) in forward().
        # NOTE(review): PyTorch's BatchNorm momentum weights the *new* batch
        # statistic (running = (1 - momentum) * running + momentum * batch), so
        # .95 does almost no averaging. If this was ported from a Caffe-style
        # 0.95 moving-average fraction, 0.05 may have been intended -- confirm.
        self.final = nn.Sequential(
            nn.Conv2d(512*2, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(512, opt.numClasses, 1),
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号