def update_grad_input(self, x, grad_output, scale=1):
    """Backpropagate ``grad_output`` through an im2col-based max pooling.

    Each upstream gradient value is written into the one row of its column
    selected by the cached ``self.x_cols_argmax``; every other entry stays
    zero. The resulting column matrix is folded back to image layout with
    ``col2im_cython`` and reshaped to ``self.x_shape``.

    Args:
        x: The layer input. Must have exactly four dimensions, which are
            unpacked as (N, C, H, W).
        grad_output: Upstream gradient. Must be 4-D because it is transposed
            with axes (2, 3, 0, 1); presumably shaped (N, C, out_h, out_w)
            -- TODO confirm against the forward pass.
        scale: Accepted for interface compatibility; not used here.

    Returns:
        The gradient with respect to the input, also stored on
        ``self.grad_input``.
    """
    # Presumably cached by the forward pass: the im2col matrix of the input
    # and, for each column, the index of the row that held the maximum.
    cols = self.x_cols
    winners = self.x_cols_argmax
    batch, channels, height, width = x.shape
    # NOTE(review): kW is passed as the field *height* and kH as the field
    # *width*. Under the usual torch convention (kW = kernel width) this looks
    # swapped; it only matters for non-square kernels. Left unchanged because
    # it must match the ordering the forward pass used to build self.x_cols --
    # fix both places together.
    field_h, field_w = self.kW, self.kH
    step = self.dW  # only dW is read; no separate vertical stride is used

    # Move axes 2 and 3 to the front before flattening, presumably so that
    # flat position k lines up with column k of the im2col matrix.
    grad_flat = grad_output.transpose(2, 3, 0, 1).ravel()

    grad_cols = np.zeros_like(cols)
    column_ids = np.arange(cols.shape[1])
    grad_cols[winners, column_ids] = grad_flat

    # Fold back as N*C single-channel images with no padding.
    grad_img = col2im_cython(grad_cols, batch * channels, 1, height, width,
                             field_h, field_w, padding=0, stride=step)

    self.grad_input = grad_img.reshape(self.x_shape)
    return self.grad_input
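# 评论列表 ("comment list") -- leftover web-page navigation text, not code.
# 文章目录 ("table of contents") -- leftover web-page navigation text, not code.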