def expand(T, w):
    """Build a windowed (im2col-style) view of ``T``, or stretch its axis 1.

    The mode is chosen by ``len(w)``:

    * ``len(w) > 1`` -- "X and Z" (im2col) expansion. A unit window is
      prepended to ``w``, and unit windows are appended for any remaining
      axes, so the window has one entry per axis of ``T``.
      ``view_as_windows`` then yields ``T.ndim`` outer axes followed by
      ``T.ndim`` window axes. The unit window axes after the fourth are
      squeezed out. The four leading window axes are moved to sit directly
      after outer axes 0-3. Finally, the prepended unit window axis is
      removed.
    * otherwise -- "Z only" (2nd-stage) expansion. Axis 1 is stretched to
      length ``w[0]`` by giving it a zero stride (``as_strided`` view). ``T``
      is returned unchanged when ``T.shape[1] == w[0]``.

    Parameters
    ----------
    T : numpy.ndarray
        Array to expand.
        NOTE(review): the im2col branch appears to be designed for
        ``len(w) == 3`` and ``T.ndim >= 4``. With ``len(w) == 2`` a unit
        window axis is left in the result. With ``len(w) > 3`` the
        ``squeeze`` fails unless the extra window sizes are 1. Confirm
        with callers.
    w : sequence of int
        Window sizes. Converted to a tuple in the im2col branch, so lists
        are accepted as well as tuples.

    Returns
    -------
    numpy.ndarray
        The expanded array, built from strided views of ``T``, or ``T``
        itself when no expansion is needed.
    """
    if len(w) > 1:  # X AND Z (im2col)
        nd = T.ndim
        # One window entry per axis of T: a leading unit window, the given
        # sizes, then unit windows for the remaining axes.
        win = (1,) + tuple(w)
        win = win + (1,) * (nd - len(win))
        T = view_as_windows(T, win)
        # The result has nd outer axes followed by nd window axes. Drop the
        # (unit) window axes past the first four.
        T = T.squeeze(tuple(range(nd + 4, 2 * nd)))
        # Move the four remaining window axes to directly follow outer axes
        # 0-3. list(range(...)) is required: range objects cannot be joined
        # with ``+`` on Python 3.
        n = T.ndim
        axes = list(range(4)) + list(range(n - 4, n)) + list(range(4, n - 4))
        T = np.transpose(T, axes)
        # Axis 4 is now the unit window prepended above; remove it.
        T = np.squeeze(T, 4)
    else:  # Z ONLY (2nd-stage expansion)
        if T.shape[1] != w[0]:
            # NOTE(review): a zero stride repeats T[:, 0] along axis 1. This
            # is a true broadcast only when T.shape[1] == 1; other lengths
            # silently drop data. Confirm callers guarantee a singleton axis.
            shape = list(T.shape)
            shape[1] = w[0]
            strides = list(T.strides)
            strides[1] = 0
            T = as_strided(T, shape, strides)
    return T
# Scraped web-page residue ("comment list", "table of contents") removed;
# it was not part of the code.