def argmaxk_rows_opt1(arr, k=10, sort=False):
    """
    Return the column indices of the k largest elements in each row of arr.

    With sort=False the k indices of each row come back in the arbitrary
    order produced by np.argpartition. With sort=True they are ordered so
    that arr[i, result[i, :]] is descending; only the (m x k) sub-array of
    selected elements is argsorted rather than the whole (m x n) array, which
    should be cheaper when k << arr.shape[1].

    Parameters
    ----------
    arr : array_like
        2-D input of shape (m, n). Converted with np.asarray.
    k : int, default 10
        Number of largest elements to select per row. Values greater than n
        are clamped to n; k == 0 yields an empty (m x 0) result.
    sort : bool, default False
        If True, order the returned indices by descending element value.

    Returns
    -------
    numpy.ndarray
        Integer index array of shape (m, min(k, n)).

    Raises
    ------
    ValueError
        If k is negative.
    """
    arr = np.asarray(arr)
    if k < 0:
        raise ValueError("k must be non-negative, got {}".format(k))
    # Never ask for more elements than a row holds; np.argpartition would
    # otherwise raise an out-of-bounds error for kth.
    k = min(k, arr.shape[1])
    if k == 0:
        # Guard: with k == 0 the slice [:, -0:] below would select every
        # column instead of none.
        return np.empty(arr[:, :0].shape, dtype=np.intp)

    # Column indices of the k max elements in each row (m x k), unordered.
    best_inds = np.argpartition(arr, kth=-k, axis=1)[:, -k:]
    if not sort:
        return best_inds

    # Values of the selected elements, gathered row by row (m x k).
    best_elems = np.take_along_axis(arr, best_inds, axis=1)
    # Positions that sort each row of best_elems in descending order (m x k).
    order = np.argsort(best_elems, axis=1)[:, ::-1]
    # Reorder the column indices so the referenced values are descending.
    return np.take_along_axis(best_inds, order, axis=1)
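# Stray web-page navigation text ("comment list"), not part of the code.
# Stray web-page navigation text ("article table of contents"), not part of the code.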