spatial_average_pooling.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:PyFunt 作者: dnlcrl 项目源码 文件源码
def update_grad_input(self, x, grad_output, scale=1):
        x_cols = self.x_cols
        dout = grad_output
        N, C, H, W = self.x_shape
        pool_height, pool_width = self.kW, self.kH
        stride = self.dW
        pool_dim = pool_height * pool_width

        dout_reshaped = dout.transpose(2, 3, 0, 1).flatten()
        dx_cols = np.zeros_like(x_cols)
        dx_cols[:, np.arange(dx_cols.shape[1])] = 1. / pool_dim * dout_reshaped
        dx = col2im_cython(dx_cols, N * C, 1, H, W, pool_height, pool_width,
                           padding=0, stride=stride)

        self.grad_input = dx.reshape(self.x_shape)

        return self.grad_input
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号