my_utils.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Y8M 作者: mpekalski 项目源码 文件源码
def bottom_top_k_along_row(arr, k, ordered=True):
    """ bottom and top k of a 2d np.array, along the rows    
    http://stackoverflow.com/questions/6910641/how-to-get-indices-of-n-maximum-values-in-a-numpy-array/18691983
    """    
    assert k>0, "bottom_top_k_along_row/column() requires k>0."
    rows = arr.shape[0]
    if ordered:
        tmp = np.argsort(arr, axis=1)      
        idx_bot = tmp[:, :k]
        idx_top = tmp[:,-k:]
    else:
        idx_bot = np.argpartition(arr, k, axis=1)[:,:k]
        idx_top = np.argpartition(arr, -k, axis=1)[:,-k:]

    indices = np.concatenate((idx_bot, idx_top), axis=1)
    vals = arr[np.repeat(np.arange(rows), 2*k), indices.ravel()].reshape(rows,2*k)
    return vals, indices
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号