Fast combinations without replacement for arrays - NumPy / Python

Posted 2021-01-29 15:00:33

I'm looking for an efficient way to generate pairwise combinations from a 1D array. itertools is too slow when n > 1000.

E.g. [1, 2, 3, 4]

magic code...

Out[2]:
array([[1, 2],
       [1, 3],
       [1, 4],
       [2, 3],
       [2, 4],
       [3, 4]])
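For comparison, here is a minimal sketch of the itertools baseline the question describes as too slow, producing the same (L, 2) array shape as the output above. The function name `pairwise_combs_itertools` is hypothetical. It still loops over every pair in Python, which is why it scales poorly as n grows.

```python
import itertools

import numpy as np


def pairwise_combs_itertools(a):
    """Baseline: all 2-combinations of `a` via itertools, as an (L, 2) array."""
    a = np.asarray(a)
    n = len(a)
    # Flatten the (x, y) tuples into one stream of scalars so np.fromiter can
    # fill a single preallocated buffer; 2 * C(n, 2) == n * (n - 1) values.
    flat = np.fromiter(
        itertools.chain.from_iterable(itertools.combinations(a.tolist(), 2)),
        dtype=a.dtype,
        count=n * (n - 1),
    )
    return flat.reshape(-1, 2)
```

Calling `pairwise_combs_itertools([1, 2, 3, 4])` gives the same six rows as `Out[2]` above.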

The closest thing I've found is here.

1 answer
  • 面试哥 (2021-01-29)

    I. Pairwise combinations

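    One way is to use numba: the loop fills a preallocated output array directly, without building intermediate arrays, which saves memory and therefore time: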

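    import numpy as np
    from numba import njit
    
    @njit
    def pairwise_combs_numba(a):
        n = len(a)
        L = n*(n-1)//2                          # number of pairs, C(n, 2)
        out = np.empty((L,2),dtype=a.dtype)     # allocate the output once
        iterID = 0                              # next output row to fill
        for i in range(n):
            for j in range(i+1,n):              # j > i: each unordered pair once
                out[iterID,0] = a[i]
                out[iterID,1] = a[j]
                iterID += 1
        return out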
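In the numba loop, `iterID` is just a running counter, but the output row of each pair (i, j) also has a closed form: row i is preceded by (n-1) + (n-2) + ... + (n-i) = i*n - i*(i+1)/2 rows, and within row i the pair sits at offset j - i - 1. The sketch below (the helper name `pair_row_index` is hypothetical) spells that out.

```python
def pair_row_index(i, j, n):
    """Row of pair (i, j), with 0 <= i < j < n, in the output of pairwise_combs_numba."""
    # Rows emitted for outer indices 0..i-1: (n-1) + (n-2) + ... + (n-i) = i*n - i*(i+1)//2
    rows_before = i * n - i * (i + 1) // 2
    # Within outer index i, j runs from i+1 upwards, so (i, j) is at offset j - i - 1.
    return rows_before + (j - i - 1)


if __name__ == "__main__":
    n = 5
    print([pair_row_index(i, j, n) for i in range(n) for j in range(i + 1, n)])
```

Because each outer iteration's starting row is known up front, the iterations do not depend on a shared counter. That is one way the outer loop could be parallelized (for example with numba's `prange`), though the answer does not do this and it has not been benchmarked here.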
    

    Another, purely NumPy-based approach uses np.broadcast_to to get gridded views of the input and then selects from them with a mask:

    def pairwise_combs_mask(a):
        n = len(a)
        L = n*(n-1)//2
        out = np.empty((L,2),dtype=a.dtype)
        m = ~np.tri(n,dtype=bool)                        # strict upper triangle: True where j > i
        out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]   # a[i] for each selected (i, j)
        out[:,1] = np.broadcast_to(a,(n,n))[m]           # a[j] for each selected (i, j)
        return out
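The mask `~np.tri(n, dtype=bool)` selects the strict upper triangle in row-major order. NumPy can return those same positions directly as index arrays via `np.triu_indices(n, k=1)`. The sketch below is an alternative not taken from the answer (the name `pairwise_combs_triu` is hypothetical). It yields the same rows in the same order and avoids materializing the boolean mask, although `triu_indices` allocates two index arrays of length C(n, 2).

```python
import numpy as np


def pairwise_combs_triu(a):
    """All 2-combinations of `a` as an (L, 2) array, via upper-triangle indices."""
    a = np.asarray(a)
    # Indices (i, j) with j > i, in row-major order: the same positions
    # that ~np.tri(n, dtype=bool) selects.
    i, j = np.triu_indices(len(a), k=1)
    return np.column_stack((a[i], a[j]))


if __name__ == "__main__":
    print(pairwise_combs_triu(np.array([3, 9, 4, 1, 7])))
```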
    

    II. Triplet combinations

    We can extend the same methods to get triplet combinations:

    @njit
    def triplet_combs_numba(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6                    # number of triplets, C(n, 3)
        out = np.empty((L,3),dtype=a.dtype)
        iterID = 0
        for i in range(n):
            for j in range(i+1,n):              # i < j < k: each triplet once
                for k in range(j+1,n):
                    out[iterID,0] = a[i]
                    out[iterID,1] = a[j]
                    out[iterID,2] = a[k]
                    iterID += 1
        return out
    
    def triplet_combs_mask(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6
        out = np.empty((L,3),dtype=a.dtype)
        r = np.arange(n)
        # m[i,j,k] is True exactly when i < j < k; shape (n, n, n)
        m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)
        # Read-only broadcast views along each axis, filtered by the mask
        out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]
        out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]
        out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]
        return out
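The same i < j < k mask can also be turned into index arrays with `np.nonzero`, which returns indices in row-major (C-style) order, i.e. lexicographic (i, j, k). This is a variant sketch, not the answer's code (the name `triplet_combs_nonzero` is hypothetical). It still builds the full n x n x n boolean mask, so its memory cost is the same as `triplet_combs_mask`.

```python
import numpy as np


def triplet_combs_nonzero(a):
    """All 3-combinations of `a` as an (L, 3) array, via np.nonzero on an i<j<k mask."""
    a = np.asarray(a)
    r = np.arange(len(a))
    # Same mask as triplet_combs_mask: m[i, j, k] is True exactly when i < j < k.
    m = (r[:, None, None] < r[:, None]) & (r[:, None] < r)
    # np.nonzero lists the True positions in C order, i.e. lexicographic (i, j, k).
    i, j, k = np.nonzero(m)
    return np.column_stack((a[i], a[j], a[k]))
```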
    

    Higher-order combinations would extend in the same way.
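Extending the mask approach to k = 4, 5, ... requires an n**k boolean mask, which grows quickly. As one possible generalization, here is a minimal sketch that instead builds the k-combinations of indices incrementally in NumPy. It starts from single indices and, at each step, appends every index larger than the current last one. The function names `combs_indices` and `combs_numpy` are hypothetical, and this technique is not the one used in the answer.

```python
import numpy as np


def combs_indices(n, k):
    """All k-combinations of range(n) as a (C(n, k), k) array, in lexicographic order."""
    if k < 1:
        raise ValueError("k must be at least 1")
    # Start with every single index as a 1-combination.
    combs = np.arange(n)[:, None]
    for _ in range(k - 1):
        last = combs[:, -1]
        # A row ending in index l can be extended by l+1, ..., n-1.
        counts = n - 1 - last
        total = counts.sum()
        # Duplicate each row once per valid extension (rows with no extension vanish).
        rows = np.repeat(combs, counts, axis=0)
        # Position of each duplicate within its own group: 0, 1, ..., count-1.
        group_start = np.repeat(np.cumsum(counts) - counts, counts)
        offsets = np.arange(total) - group_start
        # The appended index is (l + 1) plus that position.
        nxt = np.repeat(last + 1, counts) + offsets
        combs = np.column_stack((rows, nxt))
    return combs


def combs_numpy(a, k):
    """All k-combinations of the values in `a`, one combination per row."""
    a = np.asarray(a)
    return a[combs_indices(len(a), k)]
```

Each step keeps the rows in lexicographic order, so for k = 2 and k = 3 the output order matches the functions above. Note that the final array still has C(n, k) rows, so memory use is dominated by the output itself for large n.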

    Sample run:

    In [54]: a = np.array([3,9,4,1,7])
    
    In [55]: pairwise_combs_numba(a)
    Out[55]: 
    array([[3, 9],
           [3, 4],
           [3, 1],
           [3, 7],
           [9, 4],
           [9, 1],
           [9, 7],
           [4, 1],
           [4, 7],
           [1, 7]])
    
    In [56]: triplet_combs_numba(a)
    Out[56]: 
    array([[3, 9, 4],
           [3, 9, 1],
           [3, 9, 7],
           [3, 4, 1],
           [3, 4, 7],
           [3, 1, 7],
           [9, 4, 1],
           [9, 4, 7],
           [9, 1, 7],
           [4, 1, 7]])
    

    Timings (including Python's built-in itertools.combinations):

    In [68]: a = np.random.rand(4000)
    
    In [69]: %timeit pairwise_combs_numba(a)
        ...: %timeit pairwise_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 2))
    10 loops, best of 3: 52.2 ms per loop
    10 loops, best of 3: 146 ms per loop
    1 loop, best of 3: 597 ms per loop
    
    In [70]: a = np.random.rand(400)
    
    In [71]: %timeit triplet_combs_numba(a)
        ...: %timeit triplet_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 3))
    10 loops, best of 3: 98.5 ms per loop
    1 loop, best of 3: 352 ms per loop
    1 loop, best of 3: 795 ms per loop
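To put the benchmark sizes in perspective, here is a back-of-the-envelope sketch of the memory involved. It assumes float64 input (8 bytes per value, as returned by `np.random.rand`) and 1 byte per boolean mask element. It counts only the output array and the final n**k mask, ignoring temporaries such as the array `np.tri` allocates before it is inverted. The helper names are hypothetical.

```python
from math import comb


def combs_output_bytes(n, k, itemsize=8):
    """Bytes in the (C(n, k), k) output array, assuming `itemsize` bytes per value."""
    return comb(n, k) * k * itemsize


def mask_bytes(n, k):
    """Bytes in the final n**k boolean mask used by the *_mask functions (1 byte each)."""
    return n ** k


if __name__ == "__main__":
    for n, k in [(4000, 2), (400, 3)]:
        print(n, k, combs_output_bytes(n, k), mask_bytes(n, k))
```

For the pairwise run (n = 4000) this gives about 128 MB of output and a 16 MB mask. For the triplet run (n = 400) it gives about 254 MB of output and a 64 MB mask. The extra mask work is a plausible contributor to the mask versions being slower, though the timings above do not isolate it.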
    

