def sumup(mpas, weights=None):
    """Returns the (optionally weighted) sum of the MPArrays in ``mpas``.

    Without ``weights`` this is the same as

    .. code-block:: python

        functools.reduce(mp.MPArray.__add__, mpas)

    but should be faster as we can get rid of intermediate allocations.

    :param mpas: Iterator over :class:`~MPArray`; all elements must have
        the same length
    :param weights: Iterable of factors, the i-th of which is multiplied
        onto the i-th element of ``mpas`` before summing (default
        ``None``: plain sum). Entries beyond the number of summands are
        ignored, so an infinite iterator is acceptable.
    :returns: Sum of ``mpas``
    :raises ValueError: if ``weights`` provides fewer entries than there
        are elements in ``mpas``

    """
    mpas = list(mpas)
    length = len(mpas[0])
    assert all(len(mpa) == length for mpa in mpas), \
        "all summands must have the same length"

    if weights is not None:
        # Pair every summand with its weight up front. zip() stops at the
        # shorter input, so a too-short ``weights`` would otherwise silently
        # drop summands (or weight only some of the first-site tensors).
        # Going through zip instead of list() keeps infinite iterators usable.
        weights = [w for w, _ in zip(weights, mpas)]
        if len(weights) != len(mpas):
            raise ValueError(
                "Need one weight per summand: got {} weights for {} summands"
                .format(len(weights), len(mpas)))

    if length == 1:
        # Single-site arrays: the only local tensors are added directly.
        if weights is None:
            return MPArray([sum(mpa.lt[0] for mpa in mpas)])
        return MPArray([sum(w * mpa.lt[0] for w, mpa in zip(weights, mpas))])

    # One iterator per summand; each pass below consumes exactly one local
    # tensor from every summand, so the sites are processed in order.
    ltensiter = [iter(mpa.lt) for mpa in mpas]

    # First site: concatenate along the last axis. The weight is applied
    # here only, i.e. exactly once per summand.
    if weights is None:
        first_site = [next(lt) for lt in ltensiter]
    else:
        first_site = [w * next(lt) for w, lt in zip(weights, ltensiter)]
    ltens = [np.concatenate(first_site, axis=-1)]

    # Sites strictly between the first and the last are merged by _local_add.
    ltens += [_local_add([next(lt) for lt in ltensiter])
              for _ in range(length - 2)]

    # Last site: concatenate along the first axis.
    ltens += [np.concatenate([next(lt) for lt in ltensiter], axis=0)]

    return MPArray(ltens)
# Comment list
# Table of contents