def forward(self, x):
    """Run the down-sampling / up-sampling pass and return ``out_conv`` of it.

    The input goes through ``self.n_stages - 1`` rounds of
    ``self.down_convs[i]`` followed by ``self.max_pooling``; every pooled
    feature map is remembered as a skip connection. After the centre
    convolutions the remembered maps are consumed from the most recently
    pooled one to the earliest: each is concatenated with the running
    tensor along dim 1, passed through the matching ``self.up_convs``
    entry and up-sampled by a factor of 2.

    Args:
        x: Input tensor. Assumes an (N, C, H, W) layout where dim 1 is the
            channel axis and H/W survive ``n_stages - 1`` halvings
            -- TODO confirm against the caller.

    Returns:
        The result of ``self.out_conv`` applied to the last up-sampled
        tensor.
    """
    # conv & downsampling
    skip_fmaps = []
    for i in range(self.n_stages - 1):
        x = self.down_convs[i](x)
        x = self.max_pooling(x)
        skip_fmaps.append(x)

    # center convs
    x = self.down_convs[self.n_stages - 1](x)
    x = self.up_convs[0](x)

    # conv & upsampling: walk the skip connections newest-first so each one
    # is paired with the running tensor produced at the same pooling depth.
    for i, skip_fmap in enumerate(reversed(skip_fmaps), start=1):
        x = torch.cat([x, skip_fmap], 1)
        x = self.up_convs[i](x)
        # F.upsample is deprecated; F.interpolate is its replacement.
        # align_corners=False is written out because it is the value
        # current PyTorch falls back to (with a warning) when omitted.
        x = F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=False
        )

    # NOTE: a disabled earlier variant returned F.relu(out) whenever
    # self.out_conv.out_channels != 1; the raw out_conv output is returned.
    return self.out_conv(x)
networks.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录