frrn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch-semseg 作者: meetshah1995 项目源码 文件源码
def forward(self, x):
        """Run the FRRN forward pass and return the classification output.

        The input goes through ``conv1`` and three residual units, then splits
        into two streams: ``y`` (the pooling stream, which is max-pooled at
        every encoder stage and bilinearly upsampled at every decoder stage)
        and ``z`` (produced by ``split_conv``; every FRRU receives it and
        returns an updated version). After decoding, ``y`` is upsampled by 2,
        concatenated with ``z`` along dim 1, merged by ``merge_conv``, passed
        through three more residual units and finally through
        ``classif_conv``.

        Args:
            x: input tensor. NOTE(review): presumably (N, C, H, W) images whose
                H and W are divisible by the total pooling factor -- confirm
                against callers; the final ``torch.cat`` needs the upsampled
                ``y`` and ``z`` to have matching spatial sizes.

        Returns:
            The output of ``self.classif_conv``.
        """
        # Initial conv followed by the first three residual units.
        x = self.conv1(x)
        for i in range(3):
            x = self.up_residual_units[i](x)

        # Divide into the pooling stream (y) and the stream built by split_conv (z).
        y = x
        z = self.split_conv(x)

        # Encoder: halve y's spatial size per stage, then run that stage's FRRUs.
        for n_blocks, channels, scale in self.encoder_frru_specs:
            y_pooled = F.max_pool2d(y, stride=2, kernel_size=2, padding=0)
            for block in range(n_blocks):
                key = '_'.join(map(str, ['encoding_frru', n_blocks, channels,
                                         scale, block]))
                # NOTE(review): every block of a stage is fed the same pooled
                # y, so only the last block's y output is kept, while z is
                # threaded through all blocks. Presumably this matches how the
                # FRRUs are constructed (each with the stage's input channel
                # count) -- confirm against __init__ before chaining y.
                y, z = getattr(self, key)(y_pooled, z)

        # Decoder: double y's spatial size per stage, then run that stage's FRRUs.
        for n_blocks, channels, scale in self.decoder_frru_specs:
            upsample_size = torch.Size([_s * 2 for _s in y.size()[-2:]])
            # align_corners=False matches F.upsample's bilinear default since
            # PyTorch 0.4.0; F.interpolate is its non-deprecated replacement.
            y_upsampled = F.interpolate(y, size=upsample_size, mode='bilinear',
                                        align_corners=False)
            for block in range(n_blocks):
                key = '_'.join(map(str, ['decoding_frru', n_blocks, channels,
                                         scale, block]))
                # Same feeding pattern as the encoder: each block sees y_upsampled.
                y, z = getattr(self, key)(y_upsampled, z)

        # Merge streams: upsample y by 2 and concatenate with z along dim 1.
        y_full = F.interpolate(y, scale_factor=2, mode='bilinear',
                               align_corners=False)
        x = torch.cat([y_full, z], dim=1)
        x = self.merge_conv(x)

        # Final three residual units.
        for i in range(3):
            x = self.down_residual_units[i](x)

        # Final conv producing the classification output.
        return self.classif_conv(x)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号