spatial_average_pooling.py source code

python

Project: PyFunt  Author: dnlcrl
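import numpy as np
# im2col_cython: the project's compiled Cython im2col helper (import not shown here)


def update_output(self, x):
    # x has shape (N, C, H, W)
    N, C, H, W = x.shape
    # kW/kH are the kernel width/height
    pool_height, pool_width = self.kH, self.kW
    # im2col_cython takes a single stride, so dW is used in both directions
    # (this assumes dW == dH)
    stride = self.dW

    assert (H - pool_height) % stride == 0, 'Invalid height'
    assert (W - pool_width) % stride == 0, 'Invalid width'

    out_height = (H - pool_height) // stride + 1
    out_width = (W - pool_width) // stride + 1

    # Treat every (sample, channel) plane as a separate one-channel image
    x_split = x.reshape(N * C, 1, H, W)
    # One column per pooling window; the reshape below relies on the
    # columns being ordered (out_height, out_width, N * C)
    x_cols = im2col_cython(
        x_split, pool_height, pool_width, padding=0, stride=stride)
    # Averaging down each column gives one pooled value per window
    x_cols_avg = np.mean(x_cols, axis=0)
    out = x_cols_avg.reshape(
        out_height, out_width, N, C).transpose(2, 3, 0, 1)

    # Cached for later use (e.g. the backward pass)
    self.x_shape = x.shape
    self.x_cols = x_cols
    self.output = out
    return self.output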
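To make the data layout of `update_output` concrete, here is a minimal sketch that swaps `im2col_cython` for a hypothetical pure-NumPy helper, `im2col_pool`, which builds the columns in the (out_height, out_width, N * C) order that the `reshape(out_height, out_width, N, C)` call implies. Both `im2col_pool` and `spatial_average_pool` are illustrative names assumed here, not part of PyFunt, and the Python loops are only meant to show the layout, not to be fast.

```python
import numpy as np


def im2col_pool(x_split, pool_height, pool_width, stride):
    """Hypothetical pure-NumPy stand-in for im2col_cython (padding=0, one channel).

    Returns shape (pool_height * pool_width, out_height * out_width * M), where
    the window at (oy, ox) for plane m lands in column (oy * out_width + ox) * M + m.
    """
    M, _, H, W = x_split.shape
    out_height = (H - pool_height) // stride + 1
    out_width = (W - pool_width) // stride + 1
    cols = np.empty((pool_height * pool_width, out_height * out_width * M),
                    dtype=x_split.dtype)
    for oy in range(out_height):
        for ox in range(out_width):
            y0, x0 = oy * stride, ox * stride
            # The windows of all M planes at this output position: (M, ph, pw)
            patch = x_split[:, 0, y0:y0 + pool_height, x0:x0 + pool_width]
            start = (oy * out_width + ox) * M
            # Flatten each window into one column: (ph * pw, M)
            cols[:, start:start + M] = patch.reshape(M, -1).T
    return cols


def spatial_average_pool(x, pool_height, pool_width, stride):
    """Average-pool x of shape (N, C, H, W) following the same steps as update_output."""
    N, C, H, W = x.shape
    assert (H - pool_height) % stride == 0, 'Invalid height'
    assert (W - pool_width) % stride == 0, 'Invalid width'
    out_height = (H - pool_height) // stride + 1
    out_width = (W - pool_width) // stride + 1

    x_split = x.reshape(N * C, 1, H, W)
    x_cols = im2col_pool(x_split, pool_height, pool_width, stride)
    x_cols_avg = x_cols.mean(axis=0)  # one mean per window
    # Column order (out_height, out_width, N, C) -> (N, C, out_height, out_width)
    return x_cols_avg.reshape(out_height, out_width, N, C).transpose(2, 3, 0, 1)


# Usage: 2x3 window, stride 2, on a (2, 3, 4, 7) input
x = np.arange(2 * 3 * 4 * 7, dtype=float).reshape(2, 3, 4, 7)
out = spatial_average_pool(x, pool_height=2, pool_width=3, stride=2)
print(out.shape)        # (2, 3, 2, 3)
print(out[0, 0, 0, 0])  # 4.5, the mean of x[0, 0, 0:2, 0:3]
```

With the non-square 2x3 window in this example, passing height and width in the wrong order would trip the height check, since (4 - 3) % 2 != 0, so the kernel's height must be the value paired with H.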