def update_grad_input(self, input, grad_output, scale=1):
    """Back-propagate ``grad_output`` through the layer via cached columns.

    Despite its name, this method produces three gradients:
    ``self.grad_bias`` and ``self.grad_weight`` are overwritten in place,
    and the gradient with respect to the layer input is stored on
    ``self.grad_input`` and returned.

    It reads ``self.x_shape`` and ``self.x_cols``, which must already be
    set (presumably cached by the forward pass -- confirm against caller).

    Args:
        input: Not read by this method.
        grad_output: 4-D array whose axis 1 is matched against the first
            axis of ``self.weight``.
        scale: Not read by this method.

    Returns:
        The result of ``col2im_cython`` applied to the column-form input
        gradient; the same object is also stored as ``self.grad_input``.
    """
    cached_shape = self.x_shape
    cached_cols = self.x_cols
    kernel = self.weight
    step = self.dW
    padding = self.padW

    # Presumably (batch, channels, height, width) -- TODO confirm.
    num, chans, height, width = cached_shape
    num_filters, _, kernel_h, kernel_w = kernel.shape
    # The values are unused; the unpacking acts as a rank check that
    # raises unless grad_output has exactly four dimensions.
    _, _, _out_h, _out_w = grad_output.shape

    # Bias gradient: sum over every axis except axis 1.
    bias_grad = np.sum(grad_output, axis=(0, 2, 3))
    self.grad_bias[:] = bias_grad

    # Bring axis 1 to the front and flatten the rest: (num_filters, -1).
    grad_rows = grad_output.transpose(1, 0, 2, 3).reshape(num_filters, -1)

    # Weight gradient: flattened output gradient times the cached columns,
    # folded back into the kernel's own shape.
    weight_grad = grad_rows.dot(cached_cols.T)
    self.grad_weight[:] = weight_grad.reshape(kernel.shape)

    # Column-form input gradient: flattened kernel (transposed) times the
    # flattened output gradient.
    flat_kernel = kernel.reshape(num_filters, -1)
    col_grad = flat_kernel.T.dot(grad_rows)

    # NOTE: the original source also carried a commented-out call to
    # col2im_6d_cython, with a remark that dx_cols would be shaped
    # (C, HH, WW, N, out_h, out_w) for that variant.
    input_grad = col2im_cython(
        col_grad, num, chans, height, width,
        kernel_h, kernel_w, padding, step,
    )
    self.grad_input = input_grad
    return input_grad
评论列表
文章目录