spatial_max_pooling.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:PyFunt 作者: dnlcrl 项目源码 文件源码
def update_grad_input(self, x, grad_output, scale=1):
        x_cols = self.x_cols
        x_cols_argmax = self.x_cols_argmax
        dout = grad_output
        N, C, H, W = x.shape
        pool_height, pool_width = self.kW, self.kH
        stride = self.dW

        dout_reshaped = dout.transpose(2, 3, 0, 1).flatten()
        dx_cols = np.zeros_like(x_cols)
        dx_cols[x_cols_argmax, np.arange(dx_cols.shape[1])] = dout_reshaped
        dx = col2im_cython(dx_cols, N * C, 1, H, W, pool_height, pool_width,
                           padding=0, stride=stride)
        dx = dx.reshape(self.x_shape)
        self.grad_input = dx
        return self.grad_input
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号