argmaxk.py source code

python

Project: hyperstar  Author: nlpub
import numpy as np


def argmaxk_rows_opt1(arr, k=10, sort=False):
    """
    Optimized implementation. When sort=False it is equivalent to argmaxk_rows_basic. When sort=True and
    k << arr.shape[1], it should be faster, because we argsort only the subarray of the k max elements from
    each row of arr (arr.shape[0] x k) instead of the whole array arr (arr.shape[0] x arr.shape[1]).
    Assumes 1 <= k <= arr.shape[1].
    """
    # column indices of the k max elements in each row, in no particular order (m x k)
    best_inds = np.argpartition(arr, kth=-k, axis=1)[:, -k:]
    if not sort:
        return best_inds
    # generate row indices corresponding to best_inds (just the current row id in each row) (m x k)
    rows = np.arange(best_inds.shape[0], dtype=np.intp)[:, np.newaxis].repeat(best_inds.shape[1], axis=1)
    best_elems = arr[rows, best_inds]  # select the k max elements from each row using advanced indexing (m x k)
    # indices which sort each row of best_elems in descending order (m x k)
    best_elems_inds = np.argsort(best_elems, axis=1)[:, ::-1]
    # reorder best_inds so that arr[i, sorted_best_inds[i, :]] is sorted in descending order
    sorted_best_inds = best_inds[rows, best_elems_inds]
    return sorted_best_inds
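
The partition-then-sort idea used above can be checked against a full argsort on a small input. The block below is a minimal sketch, not the project's code: the helper name `argmaxk_rows_sketch` is hypothetical, and it swaps the explicit row-index array for `np.take_along_axis`, which assumes NumPy 1.15 or newer. The test matrix has no ties, so both approaches must return identical indices.

```python
import numpy as np


def argmaxk_rows_sketch(arr, k=10):
    """Hypothetical helper: column indices of the k largest values per row, largest first."""
    # Unordered column indices of the k largest elements in each row (m x k).
    best_inds = np.argpartition(arr, kth=-k, axis=1)[:, -k:]
    # Gather the selected values row by row, without building a row-index array.
    best_elems = np.take_along_axis(arr, best_inds, axis=1)
    # Argsort only the k selected values per row, then reverse to get descending order.
    order = np.argsort(best_elems, axis=1)[:, ::-1]
    return np.take_along_axis(best_inds, order, axis=1)


arr = np.array([[0.1, 0.9, 0.4, 0.7],
                [5.0, 2.0, 8.0, 1.0]])
top2 = argmaxk_rows_sketch(arr, k=2)

# Reference: full descending argsort of every row, truncated to the first k columns.
reference = np.argsort(arr, axis=1)[:, ::-1][:, :2]
print(np.array_equal(top2, reference))
```

With ties in a row, the two approaches may pick or order equal elements differently, so comparing the selected values rather than the indices is the safer check in that case.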