def forward(self, x):
    """Blend ``x`` with its 7x7 local average using a predicted weight map.

    A small two-level encoder/decoder predicts a weight map ``px`` from ``x``
    and its box-filtered version ``avg``.
    - The encoder is ``conv1_1`` .. ``conv1_3``, with 2x2 average pooling
      between levels.
    - The decoder is ``conv1_4`` and ``conv1_5``, with nearest-neighbour
      upsampling and skip concatenation.
    The result is ``px * x + (1 - px) * avg``.

    Args:
        x: Input batch, 4-D as required by ``avg_pool2d``/``conv2d``
            (presumably NCHW). ``cat([x, avg], 1)`` must match the input
            channel count of ``conv1_1``.

    Returns:
        The blended tensor. Its shape is presumably that of ``x``, assuming
        ``conv1_5`` yields one channel, ``x`` has three channels and the
        convolutions preserve spatial size. NOTE(review): confirm against
        ``__init__``.
    """
    # 7x7 box filter; stride 1 with padding 3 keeps the spatial size.
    avg = F.avg_pool2d(x, kernel_size=7, stride=1, padding=3)

    # Encoder: features at the input resolution, then after one and two
    # 2x2 average-pooling steps.
    x1_1 = F.relu(self.conv1_1(torch.cat([x, avg], 1)))
    x1_2 = F.relu(self.conv1_2(F.avg_pool2d(x1_1, kernel_size=2, stride=2)))
    x1_4 = F.relu(self.conv1_3(F.avg_pool2d(x1_2, kernel_size=2, stride=2)))

    # Decoder: upsample to the exact spatial size of each skip tensor.
    # A fixed x2 factor fails here because avg_pool2d floors odd sizes, so
    # x2 upsampling would not line up with the skip tensor and torch.cat
    # would raise. When the sizes halve evenly this matches the former
    # (deprecated) F.upsample_nearest(..., scale_factor=2) exactly.
    x1_2_ = F.interpolate(x1_4, size=x1_2.shape[-2:], mode="nearest")
    x1_2 = F.relu(self.conv1_4(torch.cat([x1_2, x1_2_], 1)))
    x1_1_ = F.interpolate(x1_2, size=x1_1.shape[-2:], mode="nearest")
    px = F.relu(self.conv1_5(torch.cat([x1_1, x1_1_], 1)))

    # Tile the weight map three times along dim 1. NOTE(review): presumably
    # conv1_5 outputs one channel and x has three; confirm.
    px = torch.cat([px, px, px], 1)
    # px >= 0 after the ReLU, so the blend weight is <= 1. It is not clamped,
    # and it goes negative once the ReLU output exceeds 16.
    px = 1 - px / 16
    return px * x + (1 - px) * avg
评论列表
文章目录