def _local_add(ltenss):
"""Computes the local tensors of a sum of MPArrays (except for the boundary
tensors)
:param ltenss: List of arrays with `ndim > 1`
:returns: Correct local tensor representation
"""
shape = ltenss[0].shape
# NOTE These are currently disabled due to real speed issues.
# if __debug__:
# for lt in ltenss[1:]:
# assert_array_equal(shape[1:-1], lt.shape[1:-1])
# FIXME: Find out whether the following code does the same as
# :func:`block_diag()` used by :func:`_local_sum_identity` and
# which implementation is faster if so.
newshape = (sum(lt.shape[0] for lt in ltenss), )
newshape += shape[1:-1]
newshape += (sum(lt.shape[-1] for lt in ltenss), )
res = np.zeros(newshape, dtype=max(lt.dtype for lt in ltenss))
pos_l, pos_r = 0, 0
for lt in ltenss:
pos_l_new, pos_r_new = pos_l + lt.shape[0], pos_r + lt.shape[-1]
res[pos_l:pos_l_new, ..., pos_r:pos_r_new] = lt
pos_l, pos_r = pos_l_new, pos_r_new
return res
评论列表
文章目录