def update_output(self, x):
    """Compute the average-pooling forward pass using im2col.

    Every (sample, channel) plane of ``x`` is treated as an independent
    single-channel image. It is unfolded into columns with
    ``im2col_cython`` and averaged over the pooling window.

    Args:
        x: array unpacked as ``(N, C, H, W)``.

    Returns:
        The pooled array of shape ``(N, C, out_height, out_width)``. The same
        object is also stored in ``self.output``.

    Raises:
        AssertionError: 'Invalid height' / 'Invalid width' if the pooling
            window is larger than the input, or if the stride does not tile
            that spatial axis exactly.

    Side effects:
        Sets ``self.x_shape``, ``self.x_cols`` and ``self.output``. These are
        presumably consumed by the backward pass -- TODO confirm.
    """
    N, C, H, W = x.shape
    # kH / kW are the kernel height / width (Torch-style naming). Height must
    # come from kH and width from kW; swapping them is only harmless for
    # square kernels.
    # NOTE(review): any backward pass that rebuilds the window from kW/kH
    # must use this same (kH, kW) orientation.
    pool_height, pool_width = self.kH, self.kW
    # im2col_cython accepts a single stride, so dW is used for both axes.
    # NOTE(review): a separate vertical stride, if the module has one, is ignored.
    stride = self.dW
    # The window must fit inside the input and the stride must tile it exactly.
    # Otherwise the reshape below would not line up with the im2col columns.
    # The explicit size check matters because Python's % on a negative
    # difference can still be 0.
    assert pool_height <= H and (H - pool_height) % stride == 0, 'Invalid height'
    assert pool_width <= W and (W - pool_width) % stride == 0, 'Invalid width'
    # Exact integer arithmetic; the divisibility checks above guarantee no remainder.
    out_height = int((H - pool_height) // stride) + 1
    out_width = int((W - pool_width) // stride) + 1
    # Fold channels into the batch axis so each plane is pooled independently.
    x_split = x.reshape(N * C, 1, H, W)
    x_cols = im2col_cython(
        x_split, pool_height, pool_width, padding=0, stride=stride)
    # Averaging over axis 0 averages each pooling window. This assumes
    # im2col_cython places window elements on rows and output positions on
    # columns -- TODO confirm against its implementation.
    x_cols_avg = np.mean(x_cols, axis=0)
    # Columns are assumed ordered (out_y, out_x, plane). Split the plane index
    # back into (N, C), then move the spatial axes last.
    out = x_cols_avg.reshape(
        out_height, out_width, N, C).transpose(2, 3, 0, 1)
    self.x_shape = x.shape
    self.x_cols = x_cols
    self.output = out
    return self.output
# Comment list
# Table of contents