def bottom_top_k_along_row(arr, k, ordered=True):
    """Return the k smallest and k largest entries of each row of a 2-D array.

    Based on
    http://stackoverflow.com/questions/6910641/how-to-get-indices-of-n-maximum-values-in-a-numpy-array/18691983

    Parameters
    ----------
    arr : 2-D np.ndarray
        Input array; the selection is performed independently on every row.
    k : int
        Number of entries to take from each end; must satisfy
        ``1 <= k <= arr.shape[1]``.
    ordered : bool, default True
        If True, a full ``argsort`` is used, so the first k columns of the
        result are the k smallest in ascending order and the last k columns
        are the k largest in ascending order.
        If False, ``argpartition`` is used: the same elements are selected,
        but their order inside each group of k is unspecified.

    Returns
    -------
    vals : np.ndarray, shape (rows, 2 * k)
        Selected values: bottom k in columns ``[:k]``, top k in ``[k:]``.
    indices : np.ndarray, shape (rows, 2 * k)
        Column indices into ``arr`` such that ``vals[i, j] ==
        arr[i, indices[i, j]]``.

    Raises
    ------
    AssertionError
        If ``k`` is not positive.
    ValueError
        If ``k`` exceeds the number of columns of ``arr``.
    """
    # Raised explicitly (instead of via `assert`) so the check survives
    # `python -O`; with k == 0 the slice `[:, -0:]` would silently select
    # every column. Exception type and message match the historical behavior.
    if not k > 0:
        raise AssertionError("bottom_top_k_along_row/column() requires k>0.")

    rows, n_cols = arr.shape[0], arr.shape[1]
    if k > n_cols:
        raise ValueError(
            "bottom_top_k_along_row() requires k <= number of columns "
            "(got k=%d, columns=%d)." % (k, n_cols)
        )

    if ordered:
        order = np.argsort(arr, axis=1)
        idx_bot = order[:, :k]
        idx_top = order[:, -k:]
    else:
        # kth = k - 1 guarantees the first k slots hold the k smallest values
        # and stays in bounds when k == n_cols (kth = k would not).
        idx_bot = np.argpartition(arr, k - 1, axis=1)[:, :k]
        # kth = -k guarantees the last k slots hold the k largest values.
        idx_top = np.argpartition(arr, -k, axis=1)[:, -k:]

    indices = np.concatenate((idx_bot, idx_top), axis=1)
    # Broadcast a column of row numbers against the (rows, 2k) column indices
    # to gather the selected values row by row.
    vals = arr[np.arange(rows)[:, None], indices]
    return vals, indices
评论列表
文章目录