如何获取numpy数组中重复元素的所有索引的列表?

发布于 2021-01-29 18:27:27

我正在尝试获取numpy数组中所有重复元素的索引,但是我目前发现的解决方案对于大型(>
20000个元素)输入数组(大约需要9秒钟的时间),实际上效率很低。这个想法很简单:

  1. records_array是一个时间戳(datetime)的numpy数组,我们要从中提取重复时间戳的索引

  2. time_array 是一个numpy数组,其中包含在中重复的所有时间戳 records_array

  3. records是一个django QuerySet(可以轻松转换为列表),其中包含一些Record对象。我们要创建一个由Record的tagId属性的所有可能组合形成的对的列表,这些对应该是从中找到的重复时间戳所对应的records_array

这是我目前可用的(但效率低下)代码:

tag_couples = [];
for t in time_array:
    users_inter = np.nonzero(records_array == t)[0] # Get all repeated timestamps in records_array for time t
    l = [str(records[i].tagId) for i in users_inter] # Create a temporary list containing all tagIds recorded at time t
    if l.count(l[0]) != len(l): #remove tuples formed by the first tag repeated
        tag_couples +=[x for x in itertools.combinations(list(set(l)),2)] # Remove duplicates with list(set(l)) and append all possible couple combinations to tag_couples

我很确定可以通过使用Numpy来优化此方法,但是我找不到不使用for循环就可以records_array与的每个元素进行比较的方法time_array(这不能仅通过使用来进行比较==,因为它们都是数组)。

关注者
0
被浏览
232
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    numpy的向量化解法,作用于unique()

    import numpy as np
    
    # create a test array
    records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])
    
    # creates an array of indices, sorted by unique element
    idx_sort = np.argsort(records_array)
    
    # sorts records array so all unique elements are together 
    sorted_records_array = records_array[idx_sort]
    
    # returns the unique values, the index of the first occurrence of a value, and the count for each element
    vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)
    
    # splits the indices into separate arrays
    res = np.split(idx_sort, idx_start[1:])
    
    #filter them with respect to their size, keeping only items occurring more than once
    vals = vals[count > 1]
    res = filter(lambda x: x.size > 1, res)
    

    下面的代码是原始答案,它需要使用numpy广播和调用unique两次的更多内存:

    records_array = array([1, 2, 3, 1, 1, 3, 4, 3, 2])
    vals, inverse, count = unique(records_array, return_inverse=True,
                                  return_counts=True)
    
    idx_vals_repeated = where(count > 1)[0]
    vals_repeated = vals[idx_vals_repeated]
    
    rows, cols = where(inverse == idx_vals_repeated[:, newaxis])
    _, inverse_rows = unique(rows, return_index=True)
    res = split(cols, inverse_rows[1:])
    

    符合预期 res = [array([0, 3, 4]), array([1, 8]), array([2, 5, 7])]



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看