def get_idx_from_arg(a, arg, axis):
    """Convert per-axis arg indices into flat (C-order) indices into ``a``.

    ``arg`` holds, for every position of ``a`` with ``axis`` removed, an
    index along ``axis``. Typically ``arg`` is the output of
    ``np.argmax(a, axis=axis)`` or ``np.argmin(a, axis=axis)``. The
    returned indices address ``a.ravel()``, so
    ``a.ravel()[get_idx_from_arg(a, arg, axis)]`` gathers the selected
    elements. ``arg`` is assumed to be laid out in C order over the
    remaining axes.

    Parameters
    ----------
    a : array_like
        The array the indices refer to; only its shape is used.
    arg : array_like of int
        Indices along ``axis``. It must contain
        ``prod(shape) // shape[axis]`` elements; any layout with that
        size is accepted, e.g. with ``keepdims=True``.
    axis : int or None
        Axis that ``arg`` indexes. Negative values count from the end.
        ``None`` means ``arg`` already holds flat indices (as returned by
        ``np.argmax(a)``); they are returned raveled.

    Returns
    -------
    numpy.ndarray
        1-D integer array of flat indices, one per element of ``arg``.

    Raises
    ------
    IndexError
        If ``axis`` is out of range for the dimensionality of ``a``.
    ValueError
        If ``arg`` does not have the expected number of elements.
    """
    arg = np.asarray(arg)
    if axis is None:
        # np.argmax(a) / np.argmin(a) without an axis already return flat indices.
        return arg.ravel()

    shp = tuple(np.shape(a))
    ndim = len(shp)
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    # Normalize negative axes. Without this, axis=-1 would pick the wrong strides.
    axis %= ndim

    # n: number of positions over the axes before `axis`.
    # m: C-order stride of `axis`, i.e. the product of the trailing dims.
    # Both are computed directly, so an empty trailing dim cannot cause a
    # 0 // 0 the way a cumulative-product ratio would.
    n = int(np.prod(shp[:axis], dtype=np.intp))
    m = int(np.prod(shp[axis + 1:], dtype=np.intp))
    # Flat distance between consecutive positions along the leading axes.
    block = shp[axis] * m

    # flat = outer_offset + arg * stride + inner_offset, broadcast over an
    # (n, m) grid. Raveling that grid in C order matches arg.ravel().
    outer = np.arange(n, dtype=np.intp)[:, None] * block
    inner = np.arange(m, dtype=np.intp)[None, :]
    return (outer + inner + arg.reshape(n, m) * m).ravel()
评论列表
文章目录