def forward(self, mid_input, global_input):
    """Fuse a per-sample global vector into every position of a feature map, then decode.

    Args:
        mid_input: 4-D tensor; dim 1 is the channel axis (it is concatenated on
            dim 1). Treated here as (batch, channels, height, width) — the usual
            layout for the conv layers that follow; TODO confirm against callers.
        global_input: 2-D tensor (batch or 1, global_channels). It is tiled over
            the spatial grid of ``mid_input``. A batch of 1 is broadcast to the
            batch of ``mid_input``.

    Returns:
        The result of ``self.upsample(self.conv5(...))``. Its channel count and
        spatial size depend on the layers configured on ``self``.
    """
    batch, _, height, width = mid_input.size()
    # Tile the global vector over the grid. expand() instead of expand_as(mid_input)
    # lets the global vector have a different channel count from mid_input.
    global_map = global_input.unsqueeze(2).unsqueeze(3).expand(batch, -1, height, width)
    fused = torch.cat((mid_input, global_map), 1)
    # Channels last, then flatten every (row, col, sample) into one row, so that fc1
    # acts as a linear layer applied independently at each spatial position.
    fused = fused.permute(2, 3, 0, 1).contiguous()
    fused = fused.view(-1, fused.size(-1))
    fused = self.bn1(self.fc1(fused))
    # Undo the flattening (rows are ordered row, col, sample) and go back to
    # (batch, channels, height, width). Sizes come from the tensors, not constants.
    fused = fused.view(height, width, batch, -1)
    x = fused.permute(2, 3, 0, 1).contiguous()
    x = F.relu(self.bn2(self.conv1(x)))
    x = self.upsample(x)
    x = F.relu(self.bn3(self.conv2(x)))
    x = F.relu(self.bn4(self.conv3(x)))
    x = self.upsample(x)
    # torch.sigmoid replaces the deprecated F.sigmoid (identical result).
    x = torch.sigmoid(self.bn5(self.conv4(x)))
    return self.upsample(self.conv5(x))
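# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): navigation text
# left over from the web page this snippet was copied from; not part of the code.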