spatial_convolution.py source code

Project: PyFunt    Author: dnlcrl
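```python
# Excerpt: a method of the spatial-convolution layer class. Assumes NumPy is
# imported as `np` and `col2im_cython` is imported at module level.
def update_grad_input(self, input, grad_output, scale=1):
    # `input` and `scale` are unused here; the forward pass cached the input
    # shape and its im2col matrix x_cols of shape (C*HH*WW, N*out_h*out_w).
    x_shape, x_cols = self.x_shape, self.x_cols
    w = self.weight

    # Only the width stride/padding are read, so they are assumed square.
    stride, pad = self.dW, self.padW

    N, C, H, W = x_shape
    F, _, HH, WW = w.shape
    _, _, out_h, out_w = grad_output.shape

    # Bias gradient: sum the upstream gradient over batch and spatial axes.
    self.grad_bias[:] = np.sum(grad_output, axis=(0, 2, 3))

    # Weight gradient. dout_reshaped has shape (F, N*out_h*out_w) with columns
    # in (n, i, j) order, which must match the column order of x_cols.
    dout_reshaped = grad_output.transpose(1, 0, 2, 3).reshape(F, -1)
    self.grad_weight[:] = dout_reshaped.dot(x_cols.T).reshape(w.shape)

    # Input gradient: map back to column space, then scatter-add the
    # columns into an (N, C, H, W) array with col2im.
    dx_cols = w.reshape(F, -1).T.dot(dout_reshaped)
    # dx_cols.shape = (C, HH, WW, N, out_h, out_w)
    # dx = col2im_6d_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)
    dx = col2im_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)
    self.grad_input = dx
    return dx
```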
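As a cross-check of the method above, the following pure-NumPy sketch recomputes the same three gradients on small random inputs and compares them with central-difference estimates. The helpers `out_size`, `im2col`, `col2im`, `conv_forward`, `conv_backward`, and `numeric_grad` are hypothetical, written for illustration only; they are slow Python loops standing in for the Cython `col2im_cython` kernel and are not PyFunt's API. The sketch assumes columns are ordered (n, i, j), which is the order that `grad_output.transpose(1, 0, 2, 3).reshape(F, -1)` produces.

```python
import numpy as np


def out_size(size, k, pad, stride):
    # Output length of a convolution along one spatial axis.
    return (size + 2 * pad - k) // stride + 1


def im2col(x, HH, WW, pad, stride):
    # Unfold x (N, C, H, W) into (C*HH*WW, N*out_h*out_w); column order (n, i, j).
    N, C, H, W = x.shape
    out_h, out_w = out_size(H, HH, pad, stride), out_size(W, WW, pad, stride)
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    cols = np.empty((C * HH * WW, N * out_h * out_w), dtype=x.dtype)
    col = 0
    for n in range(N):
        for i in range(out_h):
            for j in range(out_w):
                r, c = i * stride, j * stride
                cols[:, col] = xp[n, :, r:r + HH, c:c + WW].reshape(-1)
                col += 1
    return cols


def col2im(cols, x_shape, HH, WW, pad, stride):
    # Adjoint of im2col: scatter-add each column back onto its receptive field.
    N, C, H, W = x_shape
    out_h, out_w = out_size(H, HH, pad, stride), out_size(W, WW, pad, stride)
    xp = np.zeros((N, C, H + 2 * pad, W + 2 * pad), dtype=cols.dtype)
    col = 0
    for n in range(N):
        for i in range(out_h):
            for j in range(out_w):
                r, c = i * stride, j * stride
                xp[n, :, r:r + HH, c:c + WW] += cols[:, col].reshape(C, HH, WW)
                col += 1
    return xp[:, :, pad:pad + H, pad:pad + W]


def conv_forward(x, w, b, pad, stride):
    N, _, H, W = x.shape
    F, _, HH, WW = w.shape
    cols = im2col(x, HH, WW, pad, stride)
    out = w.reshape(F, -1).dot(cols) + b[:, None]  # (F, N*out_h*out_w)
    out_h, out_w = out_size(H, HH, pad, stride), out_size(W, WW, pad, stride)
    return out.reshape(F, N, out_h, out_w).transpose(1, 0, 2, 3), cols


def conv_backward(dout, x_shape, cols, w, pad, stride):
    # Same three steps as update_grad_input, with col2im in place of col2im_cython.
    F, _, HH, WW = w.shape
    db = dout.sum(axis=(0, 2, 3))
    dout_r = dout.transpose(1, 0, 2, 3).reshape(F, -1)  # columns in (n, i, j) order
    dw = dout_r.dot(cols.T).reshape(w.shape)
    dx = col2im(w.reshape(F, -1).T.dot(dout_r), x_shape, HH, WW, pad, stride)
    return dx, dw, db


def numeric_grad(f, a, eps=1e-6):
    # Central differences; perturbs `a` in place and restores each entry.
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + eps
        fp = f()
        a[idx] = old - eps
        fm = f()
        a[idx] = old
        g[idx] = (fp - fm) / (2 * eps)
    return g


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 5))
w = rng.standard_normal((4, 3, 3, 3))
b = rng.standard_normal(4)
pad, stride = 1, 2

out, cols = conv_forward(x, w, b, pad, stride)
dout = rng.standard_normal(out.shape)
dx, dw, db = conv_backward(dout, x.shape, cols, w, pad, stride)


def loss():
    # Scalar loss whose gradient w.r.t. the conv output is exactly dout.
    return np.sum(conv_forward(x, w, b, pad, stride)[0] * dout)


for name, analytic, param in (("dx", dx, x), ("dw", dw, w), ("db", db, b)):
    print(name, np.max(np.abs(analytic - numeric_grad(loss, param))))
```

Because `loss` is linear in each of `x`, `w`, and `b`, the printed differences should be at rounding-error level. If a forward pass orders its columns differently, for example (i, j, n), the `transpose` in the backward pass has to change to match it; otherwise the weight and input gradients come out wrong even though every shape still lines up.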