norm_query.py source code

python

Project: search-MjoLniR  Author: wikimedia
```python
import numpy as np

# `_batch` is a helper defined elsewhere in norm_query.py (not shown here);
# it yields aligned chunks of the two index arrays, at most `max_rows` long.


def _binary_sim(matrix):
    """Compute a Jaccard similarity matrix.

    Vectorization based on: https://stackoverflow.com/a/40579567

    Parameters
    ----------
    matrix : np.array
        2d binary (0/1 or bool) matrix; each row is compared to every other row.

    Returns
    -------
    np.array
        matrix of shape (n_rows, n_rows) giving the Jaccard similarity
        between rows of the input matrix.
    """
    # Generate the indices of the strict lower triangle of our result matrix.
    # The diagonal is excluded (offset -1) because a row's similarity with
    # itself is always 1.
    r, c = np.tril_indices(matrix.shape[0], -1)

    # Particularly large groups can blow out memory usage. Chunk the calculation
    # into steps that require no more than ~100MB of memory at a time.
    # We have 4 2d intermediate arrays in memory at a given time, plus the
    # input and output:
    #  p1 = max_rows * matrix.shape[1]
    #  p2 = max_rows * matrix.shape[1]
    #  intersection = max_rows * matrix.shape[1] * 4
    #  union = max_rows * matrix.shape[1] * 8
    # This adds up to:
    #  memory usage = max_rows * matrix.shape[1] * 14
    mem_limit = 100 * pow(2, 20)
    # Integer division: the batch size is used as a row count, so it must be
    # an int (and at least 1, even for very wide inputs).
    max_rows = max(1, mem_limit // (14 * matrix.shape[1]))
    # Build the result matrix with 1's on the diagonal.
    out = np.eye(matrix.shape[0])
    for c_batch, r_batch in _batch(c, r, max_rows):
        # Use those indices to build two matrices that contain all
        # the rows we need to do a similarity comparison on.
        p1 = matrix[c_batch]
        p2 = matrix[r_batch]
        # Run the main Jaccard calculation: size of intersection / size of union.
        intersection = np.logical_and(p1, p2).sum(1)
        union = np.logical_or(p1, p2).sum(1).astype(np.float64)
        # Insert the result of our similarity calculation at their original indices.
        out[c_batch, r_batch] = intersection / union
    # The loop above only populated one half of the matrix; the mirrored half
    # still contains only zeros. Fix that up by adding the transpose. Adding
    # the transposed matrix double counts the diagonal, so subtract it back
    # out. We could skip this step and leave half the matrix empty, but it
    # takes a fraction of a millisecond even on mid-sized inputs (~200 queries).
    return out + out.T - np.diag(np.diag(out))
```
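The index generation and the final mirroring step are easiest to see on a toy 4x4 case. This is a minimal sketch with made-up placeholder similarity values (not real data); it only uses `np.tril_indices` with offset -1 and the same `out + out.T - np.diag(np.diag(out))` expression as the function's return line.

```python
import numpy as np

# Strict lower triangle (k=-1): every unordered pair of rows once, no diagonal.
r, c = np.tril_indices(4, -1)
print(r.tolist())  # [1, 2, 2, 3, 3, 3]
print(c.tolist())  # [0, 0, 1, 0, 1, 2]

# Fill only one triangle (at [c, r], as the loop does), then mirror it.
out = np.eye(4)
out[c, r] = np.arange(1, 7) / 10  # placeholder values, not real similarities
full = out + out.T - np.diag(np.diag(out))
print(np.allclose(full, full.T), np.diag(full).tolist())  # True [1.0, 1.0, 1.0, 1.0]
```

Without the `- np.diag(np.diag(out))` term the diagonal would come out as 2.0, since both `out` and `out.T` carry the 1's from `np.eye`.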
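The memory-budget comment reduces to one integer division. The sketch below evaluates it for a few column counts; `rows_per_batch` is a hypothetical helper name, and it takes the comment's ~14 bytes-per-cell estimate at face value rather than measuring actual allocations.

```python
def rows_per_batch(n_cols, mem_limit=100 * 2**20, bytes_per_cell=14):
    """Row pairs per batch under the ~14 bytes/cell estimate (hypothetical helper)."""
    # Clamp to 1 so extremely wide inputs still make progress one pair at a time.
    return max(1, mem_limit // (bytes_per_cell * n_cols))


for n_cols in (100, 1_000, 10_000):
    print(n_cols, rows_per_batch(n_cols))
# 100 74898
# 1000 7489
# 10000 748
```

So a group of ~200 queries (about 19,900 row pairs) fits in a handful of batches unless each row is very wide.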