def get_im2col_indices(x_shape, filter_shape, stride, pad):
    """Build index arrays that gather im2col patches from a padded input.

    Parameters
    ----------
    x_shape : sequence of 4 ints
        ``(BS, in_D, in_H, in_W)``. The batch size is unpacked but unused.
    filter_shape : sequence of 2 ints
        ``(f_H, f_W)``, the spatial size of the filter.
    stride : sequence of 2 ints
        ``(stride_H, stride_W)``, the step between neighbouring windows.
    pad : sequence of 2 ints
        ``(pad_H, pad_W)``, padding added on each side of the height/width axes.

    Returns
    -------
    tuple of np.ndarray
        ``(c, i, j)``:

        * ``c`` has shape ``(in_D*f_H*f_W, 1)`` and holds the channel index of
          each patch row.
        * ``i`` and ``j`` have shape ``(in_D*f_H*f_W, out_H*out_W)`` and hold the
          row and column indices. These lie in ``[0, in_H + 2*pad_H)`` and
          ``[0, in_W + 2*pad_W)``, i.e. coordinates of the padded input.

        Presumably intended for fancy indexing such as
        ``x_padded[:, c, i, j]`` — confirm against the caller.

    Raises
    ------
    ValueError
        If the filter is larger than the padded input along either axis.
    """
    _, in_D, in_H, in_W = x_shape  # batch size is not needed for the indices
    f_H, f_W = filter_shape
    pad_H, pad_W = pad
    stride_H, stride_W = stride

    # Distance the filter's top-left corner can travel along each axis.
    span_H = in_H + 2 * pad_H - f_H
    span_W = in_W + 2 * pad_W - f_W
    if span_H < 0 or span_W < 0:
        raise ValueError(
            f"filter {(f_H, f_W)} does not fit in padded input "
            f"{(in_H + 2 * pad_H, in_W + 2 * pad_W)}"
        )

    # Each axis uses its own stride. Integer floor division avoids float
    # rounding and matches the usual conv output-size formula.
    out_H = int(span_H // stride_H) + 1
    out_W = int(span_W // stride_W) + 1

    # Row offsets inside one filter window, repeated for every input channel,
    # as a column vector so it broadcasts against the per-window row starts.
    i_col = np.repeat(np.arange(f_H), f_W)
    i_col = np.tile(i_col, in_D).reshape(-1, 1)
    # Starting row of each output window, in row-major order over (out_H, out_W).
    i_row = stride_H * np.repeat(np.arange(out_H), out_W)
    i = i_col + i_row  # shape=(in_D*f_H*f_W, out_H*out_W)

    # Column offsets inside one filter window, repeated for every input channel.
    j_col = np.tile(np.arange(f_W), f_H)
    j_col = np.tile(j_col, in_D).reshape(-1, 1)
    # Starting column of each output window, in row-major order over (out_H, out_W).
    j_row = stride_W * np.tile(np.arange(out_W), out_H)
    j = j_col + j_row  # shape=(in_D*f_H*f_W, out_H*out_W)

    # Channel index for every (channel, filter-row, filter-column) patch row.
    c = np.repeat(np.arange(in_D), f_H * f_W).reshape(-1, 1)  # shape=(in_D*f_H*f_W, 1)
    return (c, i, j)
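# Stray page-navigation text from the original web page ("Comment list",
# "Table of contents"); not part of the code.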