def forward(self, inputs):
    """Run the encoder-decoder (U-Net style) forward pass.

    Encoder: ``down0`` and ``down1`` run at the input resolution. Then
    ``down2`` .. ``down6`` and ``center`` each run after a 2x2 / stride-2
    max-pool, giving six pooling stages in total.

    Decoder: each ``up*`` stage bilinearly upsamples the previous output to
    the spatial size of its matching skip tensor (``d6`` .. ``d1``). It
    concatenates the two along dim 1 and applies the stage. ``f1`` and
    ``f2`` then fuse in ``d0`` and the raw ``inputs`` at full resolution.

    Args:
        inputs: Input batch. Bilinear interpolation requires 4-D tensors, so
            this is presumably (N, C, H, W) -- TODO confirm. H and W need not
            be divisible by 2**6. Every upsample targets the skip tensor's
            exact size, so odd sizes left by flooring max-pools still line up.

    Returns:
        Output of ``self.out`` with dim 1 squeezed. ``squeeze(1)`` only drops
        that dim when it has size 1 -- presumably a single logit channel;
        verify against ``self.out``.
    """

    def up_cat(x, skip):
        # Upsample to skip's exact H/W instead of a fixed 2x. For even sizes
        # this matches the old scale_factor=2 numerically; for odd sizes it
        # avoids a shape mismatch in torch.cat. align_corners=False is the
        # implicit default of the deprecated F.upsample, made explicit here
        # (which also silences its per-call warning).
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear',
                          align_corners=False)
        return torch.cat([x, skip], dim=1)

    # Encoder. No pooling between down0 and down1; both run at full size.
    d0 = self.down0(inputs)
    d1 = self.down1(d0)
    d2 = self.down2(F.max_pool2d(d1, kernel_size=2, stride=2))
    d3 = self.down3(F.max_pool2d(d2, kernel_size=2, stride=2))
    d4 = self.down4(F.max_pool2d(d3, kernel_size=2, stride=2))
    d5 = self.down5(F.max_pool2d(d4, kernel_size=2, stride=2))
    d6 = self.down6(F.max_pool2d(d5, kernel_size=2, stride=2))
    out = self.center(F.max_pool2d(d6, kernel_size=2, stride=2))

    # Decoder: deepest stage first, each paired with its same-level skip.
    decoder = (
        (self.up6, d6),
        (self.up5, d5),
        (self.up4, d4),
        (self.up3, d3),
        (self.up2, d2),
        (self.up1, d1),
    )
    for stage, skip in decoder:
        out = stage(up_cat(out, skip))

    # Full-resolution fusion with the first encoder features, then the raw input.
    out = self.f1(torch.cat([out, d0], dim=1))
    out = self.f2(torch.cat([out, inputs], dim=1))
    out = self.out(out)
    out = out.squeeze(1)  # remove logits dim (no-op unless dim 1 has size 1)
    return out
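# NOTE(review): stray web-page navigation text ("Comment list", "Table of contents") was pasted here; it is not part of the code.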