im2col.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:numpy_cnn 作者: Ryanshuai 项目源码 文件源码
def get_im2col_indices(x_shape, filter_shape, stride, pad):
    BS, in_D, in_H, in_W = x_shape
    f_H, f_W = filter_shape
    pad_H, pad_W = pad
    stride_H, stride_W = stride

    out_H = int((in_H + 2*pad_H - f_H) / stride_W + 1)
    out_W = int((in_W + 2*pad_W - f_W) / stride_W + 1)

    i_col = np.repeat(np.arange(f_H), f_W)
    i_col = np.tile(i_col, in_D).reshape(-1, 1)
    i_row = stride_H * np.repeat(np.arange(out_H), out_W)
    i = i_col + i_row #shape=(in_D*f_H*f_W,out_H*out_W)

    j_col = np.tile(np.arange(f_W), f_H)
    j_col = np.tile(j_col, in_D).reshape(-1, 1)
    j_row = stride_W * np.tile(np.arange(out_W), out_H)
    j = j_col + j_row #shape=(in_D*f_H*f_W,out_W*out_H)

    c = np.repeat(np.arange(in_D), f_H * f_W).reshape(-1, 1) #shape=(in_D*f_H*f_W,1)

    return (c, i, j)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号