Fast combinations without replacement for arrays - NumPy / Python
Posted on 2021-01-29 15:00:33
I'm after an efficient way to generate pairwise combinations from a 1D array. Itertools is too inefficient when n > 1000.
E.g. [1, 2, 3, 4]

    magic code...
    Out[2]:
    array([[1, 2],
           [1, 3],
           [1, 4],
           [2, 3],
           [2, 4],
           [3, 4]])
The closest thing is here.
1 answer
I. Pairwise combinations
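One way is to use numba, for memory efficiency and hence performance -

    import numpy as np
    from numba import njit

    @njit
    def pairwise_combs_numba(a):
        n = len(a)
        L = n*(n-1)//2                          # number of pairs, C(n, 2)
        out = np.empty((L,2),dtype=a.dtype)     # preallocate the output once
        iterID = 0
        for i in range(n):
            for j in range(i+1,n):              # j > i: each pair once, no repeats
                out[iterID,0] = a[i]
                out[iterID,1] = a[j]
                iterID += 1
        return out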
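To see what the numba kernel computes without needing numba installed, here is a sketch of the same double loop in plain Python. The name `pairwise_combs_loop` is hypothetical and not part of the answer; it is far too slow for large n and is only meant to illustrate the fill order and the `L = n*(n-1)//2` preallocation size.

```python
import itertools
import numpy as np

def pairwise_combs_loop(a):
    """Hypothetical plain-Python twin of pairwise_combs_numba (no @njit)."""
    n = len(a)
    L = n * (n - 1) // 2              # number of pairs, C(n, 2)
    out = np.empty((L, 2), dtype=a.dtype)
    iterID = 0
    for i in range(n):
        for j in range(i + 1, n):     # j > i, so pairs come out in lexicographic order
            out[iterID, 0] = a[i]
            out[iterID, 1] = a[j]
            iterID += 1
    return out

a = np.array([1, 2, 3, 4])
print(pairwise_combs_loop(a))
```

The row order matches `itertools.combinations(a, 2)`, which is what makes the two directly comparable.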
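Another one is NumPy-based: np.broadcast_to is used to get gridded views, which are then masked -

    def pairwise_combs_mask(a):
        n = len(a)
        L = n*(n-1)//2
        out = np.empty((L,2),dtype=a.dtype)
        m = ~np.tri(len(a),dtype=bool)                    # strict upper triangle: True where i < j
        out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]    # view where [i, j] -> a[i]
        out[:,1] = np.broadcast_to(a,(n,n))[m]            # view where [i, j] -> a[j]
        return out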
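A closely related NumPy variant (not from the answer, and untimed here) replaces the boolean mask with `np.triu_indices(n, k=1)`, which returns the row and column indices of the strict upper triangle in row-major order, i.e. exactly the `i < j` pairs. The function name `pairwise_combs_triu` is hypothetical.

```python
import numpy as np

def pairwise_combs_triu(a):
    """Hypothetical variant: pairs (a[i], a[j]) with i < j via upper-triangle indices."""
    n = len(a)
    i, j = np.triu_indices(n, k=1)    # (0,1), (0,2), ..., (n-2, n-1) in row-major order
    return np.stack((a[i], a[j]), axis=1)

print(pairwise_combs_triu(np.array([3, 9, 4, 1, 7])))
```

This materializes two integer index arrays of length `n*(n-1)//2` before gathering, so it trades some memory for simpler code; whether it beats the mask version would need measuring.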
II. Triplet combinations
We will extend the same methods to get triplet combinations -

    @njit
    def triplet_combs_numba(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6                    # number of triplets, C(n, 3)
        out = np.empty((L,3),dtype=a.dtype)
        iterID = 0
        for i in range(n):
            for j in range(i+1,n):
                for k in range(j+1,n):          # i < j < k
                    out[iterID,0] = a[i]
                    out[iterID,1] = a[j]
                    out[iterID,2] = a[k]
                    iterID += 1
        return out

    def triplet_combs_mask(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6
        out = np.empty((L,3),dtype=a.dtype)
        r = np.arange(n)
        m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)    # True where i < j < k
        out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]
        out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]
        out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]
        return out
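The 3D mask in `triplet_combs_mask` can also be turned into index arrays with `np.nonzero`, which returns the True coordinates in C (row-major) order, i.e. lexicographic `(i, j, k)` order. This is a sketch of that alternative, with the hypothetical name `triplet_combs_nonzero`; it still builds the full `n x n x n` mask, so memory grows as n**3.

```python
import numpy as np

def triplet_combs_nonzero(a):
    """Hypothetical variant: gather triplets from the i < j < k mask via np.nonzero."""
    n = len(a)
    r = np.arange(n)
    # Same mask as triplet_combs_mask: m[i, j, k] is True exactly when i < j < k
    m = (r[:, None, None] < r[:, None]) & (r[:, None] < r)
    i, j, k = np.nonzero(m)           # coordinates of True cells, row-major order
    return np.stack((a[i], a[j], a[k]), axis=1)

print(triplet_combs_nonzero(np.array([3, 9, 4, 1, 7])))
```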
Higher-order combinations would extend similarly.
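As one reading of "extend similarly", here is a sketch that generalizes the masking idea to any combination size `k`. The function `kcombs_mask` is hypothetical, not from the answer. It uses `np.indices` to get the k index grids, keeps cells where the indices are strictly increasing, and relies on boolean indexing visiting cells in C order. The mask has `n**k` cells, so this only makes sense for small n or small k.

```python
import itertools
import numpy as np

def kcombs_mask(a, k):
    """Hypothetical generalization of the *_combs_mask functions to any k >= 1."""
    n = len(a)
    idx = np.indices((n,) * k)          # shape (k, n, ..., n); idx[d] holds the d-th index
    m = np.ones((n,) * k, dtype=bool)
    for d in range(k - 1):
        m &= idx[d] < idx[d + 1]        # keep cells where i0 < i1 < ... < i(k-1)
    # Boolean indexing walks the cells in C order, i.e. lexicographic index order
    return np.stack([a[idx[d][m]] for d in range(k)], axis=1)

a = np.array([3, 9, 4, 1, 7])
print(kcombs_mask(a, 3))
```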
Sample run -
    In [54]: a = np.array([3,9,4,1,7])

    In [55]: pairwise_combs_numba(a)
    Out[55]: 
    array([[3, 9],
           [3, 4],
           [3, 1],
           [3, 7],
           [9, 4],
           [9, 1],
           [9, 7],
           [4, 1],
           [4, 7],
           [1, 7]])

    In [56]: triplet_combs_numba(a)
    Out[56]: 
    array([[3, 9, 4],
           [3, 9, 1],
           [3, 9, 7],
           [3, 4, 1],
           [3, 4, 7],
           [3, 1, 7],
           [9, 4, 1],
           [9, 4, 7],
           [9, 1, 7],
           [4, 1, 7]])
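To check outputs like the ones above against the `itertools` baseline as a NumPy array, one option is to flatten the tuples with `itertools.chain.from_iterable` and feed them to `np.fromiter`, then reshape to `(L, r)`. This helper, `combs_itertools`, is a hypothetical sketch for comparison only.

```python
import itertools
import numpy as np

def combs_itertools(a, r):
    """Hypothetical baseline: itertools.combinations flattened into an (L, r) array."""
    flat = itertools.chain.from_iterable(itertools.combinations(a, r))
    return np.fromiter(flat, dtype=a.dtype).reshape(-1, r)

a = np.array([3, 9, 4, 1, 7])
print(combs_itertools(a, 3))
```

Because `itertools.combinations` emits tuples in lexicographic index order, `np.array_equal(combs_itertools(a, 2), pairwise_combs_numba(a))` should hold for any 1D `a`.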
Timings (including Python's built-in itertools.combinations) -

    In [68]: a = np.random.rand(4000)

    In [69]: %timeit pairwise_combs_numba(a)
        ...: %timeit pairwise_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 2))
    10 loops, best of 3: 52.2 ms per loop
    10 loops, best of 3: 146 ms per loop
    1 loop, best of 3: 597 ms per loop

    In [70]: a = np.random.rand(400)

    In [71]: %timeit triplet_combs_numba(a)
        ...: %timeit triplet_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 3))
    10 loops, best of 3: 98.5 ms per loop
    1 loop, best of 3: 352 ms per loop
    1 loop, best of 3: 795 ms per loop
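To rerun a comparison outside IPython, a rough sketch with the standard `timeit` module is below. The `bench` function, its defaults (`n=1000`, `number=3`) and the dictionary of candidates are assumptions for illustration; numba is left out so the script has no extra dependency, and no timings are claimed here.

```python
import itertools
import timeit
import numpy as np

def pairwise_combs_mask(a):
    # Copied from the answer: strict-upper-triangle mask over broadcast views
    n = len(a)
    L = n * (n - 1) // 2
    out = np.empty((L, 2), dtype=a.dtype)
    m = ~np.tri(n, dtype=bool)
    out[:, 0] = np.broadcast_to(a[:, None], (n, n))[m]
    out[:, 1] = np.broadcast_to(a, (n, n))[m]
    return out

def bench(n=1000, number=3):
    """Hypothetical harness: best-of-3 seconds per call for each candidate."""
    a = np.random.rand(n)
    candidates = {
        "pairwise_combs_mask": lambda: pairwise_combs_mask(a),
        "itertools.combinations": lambda: list(itertools.combinations(a, 2)),
    }
    results = {}
    for name, fn in candidates.items():
        # timeit.repeat returns total seconds for `number` calls, once per repeat
        results[name] = min(timeit.repeat(fn, number=number, repeat=3)) / number
    return results

if __name__ == "__main__":
    for name, sec in bench().items():
        print(f"{name}: {sec * 1e3:.1f} ms per call")
```

If you add `pairwise_combs_numba` to the candidates, call it once before timing: the first call to an `@njit` function includes JIT compilation, which would otherwise inflate the measurement.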