def argmin_n(m, n):
    """Return the indices of the ``n`` smallest entries of ``m``.

    ``m`` is walked once with ``np.ndenumerate`` while a bounded max-heap
    tracks the ``n`` smallest values seen so far, so at most ``n``
    candidates are held at any time.

    Args:
        m: Array-like accepted by ``np.ndenumerate``.
        n: Number of smallest entries to select.

    Returns:
        A list of index tuples (as produced by ``np.ndenumerate``) ordered
        by ascending value. The list is shorter than ``n`` when ``m`` has
        fewer than ``n`` entries, and empty when ``n`` is 0 or ``m`` has no
        entries. A negative ``n`` never reaches capacity, so every index is
        returned.
    """
    if n == 0:
        # Nothing to select; also avoids peeking at an empty heap below.
        return []

    best_values = []
    best_index = []
    # Holds (-value, slot) pairs: negating turns heapq's min-heap into a
    # max-heap, so entry 0 always refers to the largest value currently kept.
    max_value_heap = []

    for index, value in np.ndenumerate(m):
        # ``-1 * value`` rather than ``-value``: unary minus is not defined
        # for NumPy booleans.
        negated = -1 * value
        if len(best_values) != n:
            # Still below capacity: keep everything.
            heapq.heappush(max_value_heap, (negated, len(best_values)))
            best_values.append(value)
            best_index.append(index)
            continue

        if negated < max_value_heap[0][0]:
            # value is larger than the largest kept value and the list is
            # at capacity, so it cannot be among the n smallest.
            continue

        # Overwrite the slot of the current largest kept value. A value
        # equal to that largest value also lands here and replaces it.
        slot = max_value_heap[0][1]
        best_values[slot] = value
        best_index[slot] = index
        heapq.heapreplace(max_value_heap, (negated, slot))

    # Stable sort of slot numbers by their value; yields [] when nothing was
    # collected instead of failing on an empty unpack.
    order = sorted(range(len(best_values)), key=best_values.__getitem__)
    return [best_index[slot] for slot in order]
评论列表
文章目录