def __init__(self, opt):
    """Build a PSPNet-style segmentation network on a torchvision ResNet backbone.

    Args:
        opt: configuration object. Reads ``opt.netSpec`` (one of
            ``'resnet34'``, ``'resnet50'``, ``'resnet101'``), ``opt.pretrain``
            (forwarded as ``pretrained=`` to the torchvision constructor) and
            ``opt.numClasses`` (output channels of the final 1x1 conv).

    Raises:
        ValueError: if ``opt.netSpec`` is not one of the supported names.
    """
    super().__init__()
    self.opt = opt

    # Dispatch table instead of an if/elif chain. Previously an unknown name
    # fell through and crashed on the next line with UnboundLocalError.
    backbones = {
        'resnet101': models.resnet101,
        'resnet50': models.resnet50,
        'resnet34': models.resnet34,
    }
    try:
        build_backbone = backbones[opt.netSpec]
    except KeyError:
        raise ValueError(
            'unsupported opt.netSpec {!r}; expected one of {}'.format(
                opt.netSpec, sorted(backbones))
        ) from None
    resnet = build_backbone(pretrained=opt.pretrain)

    # Reuse conv1 and the four residual stages of the torchvision ResNet.
    # bn1 / relu / maxpool are not copied here.
    self.conv1 = resnet.conv1
    self.layer1 = resnet.layer1
    self.layer2 = resnet.layer2
    self.layer3 = resnet.layer3
    self.layer4 = resnet.layer4

    # Freeze the Conv2d and BatchNorm2d parameters registered so far, i.e. the
    # backbone: this runs before the decoder and head are created.
    # requires_grad must be set on each Parameter. Assigning it on the module
    # (as the previous code did) only adds a plain attribute and freezes nothing.
    # NOTE: this does not stop BatchNorm running-stat updates in train() mode.
    # NOTE(review): when opt.pretrain is false, this freezes randomly
    # initialised weights -- confirm that is intended.
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
            for p in m.parameters():
                p.requires_grad = False

    # Four PSPDec branches, each mapping 512 -> 128 channels. The tuple argument
    # is presumably the pooling output size (1x1, 2x2, 3x3, 6x6) -- confirm
    # against PSPDec.
    # NOTE(review): 512 input channels matches resnet34's layer4 output. The
    # resnet50/101 layer4 outputs 2048 channels, so confirm which feature map
    # forward() feeds to these branches for those backbones.
    self.layer5a = PSPDec(512, 128, (1, 1))
    self.layer5b = PSPDec(512, 128, (2, 2))
    self.layer5c = PSPDec(512, 128, (3, 3))
    self.layer5d = PSPDec(512, 128, (6, 6))

    # Classifier head. The 512 * 2 input channels presumably are the 512-channel
    # features concatenated with the 4 x 128 branch outputs in forward() -- verify.
    # NOTE(review): PyTorch BatchNorm momentum weights the *new* batch statistic
    # (running = (1 - momentum) * running + momentum * batch), so 0.95 is not a
    # Caffe-style 0.95 decay (that would be ~0.05 here). Confirm which was meant.
    self.final = nn.Sequential(
        nn.Conv2d(512 * 2, 512, 3, padding=1, bias=False),
        nn.BatchNorm2d(512, momentum=.95),
        nn.ReLU(inplace=True),
        nn.Dropout(.1),
        nn.Conv2d(512, opt.numClasses, 1),
    )
# Residue from a scraped web page, not code:
#   "评论列表" = "Comment list", "文章目录" = "Table of contents".