splitgconv2d.py 文件源码

python
阅读 68 收藏 0 点赞 0 评论 0

项目：GrouPy　作者：tscohen　项目源码　文件源码
def __call__(self, x):
    """Apply the split group convolution to ``x``.

    The group filters are optionally masked, transformed with
    ``TransformGFilter(self.inds)``, folded into an ordinary 2D filter bank
    and applied with ``F.convolution_2d``. The output is then unfolded so a
    single bias per output channel can be shared across the output
    stabilizer axis, and finally re-flattened if ``self.flat_channels`` is
    True.

    Args:
        x: Input feature maps. When ``self.flat_channels`` is False, ``x`` is
            reshaped to ``(batch, in_channels * input_stabilizer_size, ny, nx)``
            using its first dimension and its last two dimensions; presumably
            it arrives as ``(batch, in_channels, input_stabilizer_size, ny, nx)``
            -- TODO confirm against callers. When ``self.flat_channels`` is
            True, ``x`` is passed to the convolution unchanged.

    Returns:
        ``(batch, out_channels * output_stabilizer_size, ny_out, nx_out)`` when
        ``self.flat_channels`` is True, otherwise
        ``(batch, out_channels, output_stabilizer_size, ny_out, nx_out)``.
    """
    # Optionally mask the filters. W and the mask are broadcast to a common
    # shape before the elementwise product.
    if self.filter_mask is not None:
        w, m = F.broadcast(self.W, Variable(self.filter_mask))
        w = w * m
    else:
        w = self.W

    # Transform the filters.
    # w.shape  == (out_channels, in_channels, input_stabilizer_size, ksize, ksize)
    # tw.shape == (out_channels, output_stabilizer_size, in_channels,
    #              input_stabilizer_size, ksize, ksize)
    tw = TransformGFilter(self.inds)(w)

    # Fold the transformed filters into an ordinary 2D filter bank.
    tw_shape = (self.out_channels * self.output_stabilizer_size,
                self.in_channels * self.input_stabilizer_size,
                self.ksize, self.ksize)
    tw = F.reshape(tw, tw_shape)

    # Without flat channels, merge the input's channel and stabilizer axes so
    # it has a single 1D feature dimension, as convolution_2d expects.
    if not self.flat_channels:
        batch_size = x.data.shape[0]
        in_ny, in_nx = x.data.shape[-2:]
        x = F.reshape(
            x,
            (batch_size, self.in_channels * self.input_stabilizer_size,
             in_ny, in_nx))

    # Perform the 2D convolution; the bias is added separately below.
    y = F.convolution_2d(x, tw, b=None, stride=self.stride, pad=self.pad,
                         use_cudnn=self.use_cudnn)

    # Unfold the output even when flat_channels is True: the same bias must be
    # added to every stabilizer slice of a given output channel.
    batch_size, _, ny_out, nx_out = y.data.shape
    y = F.reshape(
        y,
        (batch_size, self.out_channels, self.output_stabilizer_size,
         ny_out, nx_out))

    # Add one bias per output channel, broadcast over batch, stabilizer and
    # spatial axes.
    if self.usebias:
        bb = F.reshape(self.b, (1, self.out_channels, 1, 1, 1))
        y, b = F.broadcast(y, bb)
        y = y + b

    # Flatten the channel and stabilizer axes back together if requested.
    if self.flat_channels:
        n, nc, ng, ny, nx = y.data.shape
        y = F.reshape(y, (n, nc * ng, ny, nx))

    return y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号