mparray.py source code

python

Project: mpnum, author: dseuss
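import numpy as np

# MPArray and _local_add are assumed to come from the surrounding mparray.py module.
def sumup(mpas, weights=None):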
    """Returns the sum of the MPArrays in ``mpas``. Same as

    .. code-block:: python

        functools.reduce(mp.MPArray.__add__, mpas)

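    but should be faster because it avoids intermediate allocations.

    :param mpas: Iterable of :class:`~MPArray` instances, all with the
        same length
    :param weights: Optional iterable of scalar weights, one per element
        of ``mpas``; if given, the weighted sum is returned
    :returns: (Weighted) sum of ``mpas``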

    """
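    # Materialize the input: it is indexed and traversed more than once
    mpas = list(mpas)
    length = len(mpas[0])
    # All summands must have the same number of sites
    assert all(len(mpa) == length for mpa in mpas)
    if weights is not None:
        # A length mismatch would make zip() below silently skip summands
        weights = list(weights)
        assert len(weights) == len(mpas)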

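    # A single site has no bonds to extend, so the local tensors are
    # simply added (weighted, if weights are given)
    if length == 1: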
        if weights is None:
            return MPArray([sum(mpa.lt[0] for mpa in mpas)])
        else:
            return MPArray([sum(w * mpa.lt[0] for w, mpa in zip(weights, mpas))])

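    ltensiter = [iter(mpa.lt) for mpa in mpas]
    # First site: place the local tensors side by side along the right
    # bond (last axis). Weights are applied only here, which scales each
    # summand as a whole.
    if weights is None:
        ltens = [np.concatenate([next(lt) for lt in ltensiter], axis=-1)]
    else:
        ltens = [np.concatenate([w * next(lt)
                                 for w, lt in zip(weights, ltensiter)], axis=-1)]
    # Bulk sites: combine the local tensors with _local_add
    # (block-diagonal in the left and right bond indices)
    ltens += [_local_add([next(lt) for lt in ltensiter])
              for _ in range(length - 2)]
    # Last site: stack the local tensors along the left bond (axis 0)
    ltens += [np.concatenate([next(lt) for lt in ltensiter], axis=0)]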

    return MPArray(ltens)
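The function relies on the fact that a sum of N MPAs has an exact MPA representation whose bond dimension is the sum of the summands' bond dimensions. The first local tensors sit side by side along the right bond, each bulk site gets a block-diagonal tensor (presumably what `_local_add` builds), and the last local tensors are stacked along the left bond. Weights only need to multiply the first tensor, because that scales each summand as a whole. Below is a minimal standalone NumPy sketch of that construction. It does not use mpnum. It assumes each MPA is a plain list of arrays with shape `(left_bond, *physical_legs, right_bond)` and boundary bond dimension 1. `local_add_sketch`, `sumup_sketch`, `to_dense` and `random_mpa` are hypothetical helpers written for illustration, not mpnum functions.

```python
import numpy as np


def local_add_sketch(ltens):
    """Hypothetical helper: block-diagonal direct sum of bulk local tensors.

    Each tensor has shape (left, *phys, right); the result has shape
    (sum of lefts, *phys, sum of rights), with tensor i in diagonal block i.
    """
    phys = ltens[0].shape[1:-1]
    left = sum(t.shape[0] for t in ltens)
    right = sum(t.shape[-1] for t in ltens)
    out = np.zeros((left,) + phys + (right,), dtype=np.result_type(*ltens))
    l0 = r0 = 0
    for t in ltens:
        l, r = t.shape[0], t.shape[-1]
        out[l0:l0 + l, ..., r0:r0 + r] = t
        l0 += l
        r0 += r
    return out


def sumup_sketch(mpas, weights=None):
    """Hypothetical helper: (weighted) sum of MPAs given as lists of local tensors."""
    mpas = list(mpas)
    length = len(mpas[0])
    assert all(len(m) == length for m in mpas)
    weights = [1] * len(mpas) if weights is None else list(weights)
    assert len(weights) == len(mpas)
    if length == 1:
        # No bonds: add the single local tensors directly
        return [sum(w * m[0] for w, m in zip(weights, mpas))]
    # First site: weighted tensors side by side along the right bond
    first = np.concatenate([w * m[0] for w, m in zip(weights, mpas)], axis=-1)
    # Bulk sites: block-diagonal in both bond indices
    bulk = [local_add_sketch([m[k] for m in mpas]) for k in range(1, length - 1)]
    # Last site: tensors stacked along the left bond
    last = np.concatenate([m[-1] for m in mpas], axis=0)
    return [first] + bulk + [last]


def to_dense(mpa):
    """Hypothetical helper: contract all bond indices into one dense array."""
    result = mpa[0]
    for t in mpa[1:]:
        result = np.tensordot(result, t, axes=(-1, 0))
    # Drop the trivial boundary bonds (dimension 1)
    return result[0, ..., 0]


rng = np.random.default_rng(0)


def random_mpa(length, phys, rank):
    """Hypothetical helper: random MPA with uniform bond dimension `rank`."""
    dims = [1] + [rank] * (length - 1) + [1]
    return [rng.standard_normal((dims[i], phys, dims[i + 1])) for i in range(length)]


mpas = [random_mpa(4, 2, 3) for _ in range(3)]
weights = [0.5, -1.0, 2.0]
total = sumup_sketch(mpas, weights)
expected = sum(w * to_dense(m) for w, m in zip(weights, mpas))
print(np.allclose(to_dense(total), expected))  # True
print([t.shape for t in total])  # [(1, 2, 9), (9, 2, 9), (9, 2, 9), (9, 2, 1)]
```

The printed shapes show the cost of this exact construction: the bond dimension of the result is the sum of the summands' bond dimensions (9 = 3 + 3 + 3 here), so repeated sums grow quickly.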