def __init__(self, num_classes, input_size, pretrained=True, use_aux=True):
    """Build a PSPNet on a ResNet-101 backbone whose last two stages use deformable 3x3 convs.

    The downsampling stride of ``layer3`` and ``layer4`` is set to 1. Every
    block's ``conv2`` in those stages is then wrapped in ``Conv2dDeformable``.

    Args:
        num_classes: Number of output channels of the final 1x1 classifier. When
            ``use_aux`` is True, this is also the output width of the auxiliary
            classifier.
        input_size: Stored as ``self.input_size``; not used in this method
            (presumably consumed by ``forward`` -- confirm).
        pretrained: If True, load backbone weights from ``res101_path``.
        use_aux: If True, build an auxiliary 1x1 classifier ``aux_logits`` on
            1024-channel features (presumably the ``layer3`` output -- confirm).
    """
    super(PSPNetDeform, self).__init__()
    self.input_size = input_size
    self.use_aux = use_aux

    resnet = models.resnet101()
    if pretrained:
        # Remap every storage to CPU so a checkpoint saved from GPU tensors also
        # loads on a CPU-only machine. load_state_dict copies the values into
        # the model's own parameters, so the result is the same either way.
        resnet.load_state_dict(
            torch.load(res101_path, map_location=lambda storage, loc: storage))
    self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    self.layer1 = resnet.layer1
    self.layer2 = resnet.layer2
    self.layer3 = resnet.layer3
    self.layer4 = resnet.layer4

    deform_stages = (self.layer3, self.layer4)

    # Force stride 1 on every block's conv2 and on the first module of each
    # downsample branch. In torchvision's Bottleneck these carry the stage's
    # stride-2, so these stages presumably no longer reduce resolution --
    # confirm. Padding (1, 1) assumes conv2 is the 3x3 conv.
    for stage in deform_stages:
        for name, module in stage.named_modules():
            if 'conv2' in name:
                module.padding = (1, 1)
                module.stride = (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)

    # Replace each block's conv2 with a deformable wrapper around it. This is
    # kept as a separate pass after all stride changes, layer3 before layer4,
    # so any random initialisation inside Conv2dDeformable happens in the
    # same order as before.
    for stage in deform_stages:
        for block in stage:
            block.conv2 = Conv2dDeformable(block.conv2)

    # Pyramid pooling on the 2048-channel backbone output; (1, 2, 3, 6) are
    # presumably the pooling bin sizes -- confirm against _PyramidPoolingModule.
    self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
    # 4096 input channels presumably = 2048 backbone channels + 4 x 512 pooled
    # branches concatenated by the PPM -- confirm.
    self.final = nn.Sequential(
        nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(512, momentum=.95),
        nn.ReLU(inplace=True),
        nn.Dropout(0.1),
        nn.Conv2d(512, num_classes, kernel_size=1)
    )

    if use_aux:
        self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
        initialize_weights(self.aux_logits)

    # Only the newly created heads are (re)initialised; the backbone keeps its
    # constructed or pretrained weights.
    initialize_weights(self.ppm, self.final)
psp_net.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录