extras.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:lambda-numba 作者: rlhotovy 项目源码 文件源码
def _median(a, axis=None, out=None, overwrite_input=False):
    """Return the median of `a` along `axis`, skipping masked entries.

    NOTE(review): `sort` and `count` are defined outside this block; this
    code assumes they are the ``numpy.ma`` functions (masked entries are
    sorted to the end of the axis, and ``count`` returns the number of
    unmasked entries) -- confirm against the module imports.

    Parameters
    ----------
    a : masked array
        Input data. Sorted in place when `overwrite_input` is true.
    axis : int or None, optional
        Axis to reduce. ``None`` reduces over the flattened data.
    out : array, optional
        Forwarded to the final ``mean``/``sum`` reduction.
    overwrite_input : bool, optional
        If true, sort `a` (or its ravelled form) in place instead of
        sorting a copy.

    Returns
    -------
    The mean of the one or two middle unmasked values along `axis`.
    """
    if overwrite_input:
        if axis is None:
            asorted = a.ravel()
            asorted.sort()
        else:
            a.sort(axis=axis)
            asorted = a
    else:
        asorted = sort(a, axis=axis)

    if axis is None:
        axis = 0
    elif axis < 0:
        axis += asorted.ndim

    if asorted.ndim == 1:
        # With `odd` set the slice holds the single middle element,
        # otherwise the two elements around the middle.
        idx, odd = divmod(count(asorted), 2)
        return asorted[idx + odd - 1 : idx + 1].mean(out=out)

    if asorted.shape[axis] == 0:
        # Integer indexing below would raise IndexError on an empty axis;
        # reduce an empty slice instead so the n-D case behaves like the
        # 1-D case (mean of nothing).
        indexer = [slice(None)] * asorted.ndim
        indexer[axis] = slice(0, 0)
        return np.ma.mean(asorted[tuple(indexer)], axis=axis, out=out)

    counts = count(asorted, axis=axis)
    h = counts // 2

    # create indexing mesh grid for all but reduced axis
    axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
                 if i != axis]
    # meshgrid returns a tuple on NumPy >= 2.0; a list is required because
    # the index for the reduced axis is inserted and then replaced below.
    ind = list(np.meshgrid(*axes_grid, sparse=True, indexing='ij'))

    # insert indices of low and high median
    ind.insert(axis, np.maximum(0, h - 1))
    low = asorted[tuple(ind)]
    ind[axis] = h
    high = asorted[tuple(ind)]

    # duplicate high if odd number of elements so mean does nothing
    # (asorted.ndim is always > 1 here: the 1-D case returned above)
    odd = counts % 2 == 1
    np.copyto(low, high, where=odd)

    if np.issubdtype(asorted.dtype, np.inexact):
        # avoid inf / x = masked
        s = np.ma.sum([low, high], axis=0, out=out)
        np.true_divide(s.data, 2., casting='unsafe', out=s.data)
    else:
        s = np.ma.mean([low, high], axis=0, out=out)
    return s
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号