mparray.py source code

python

Project: mpnum  Author: dseuss
import numpy as np


def _local_add(ltenss):
    """Computes the local tensors of a sum of MPArrays (except for the boundary
    tensors)

    :param ltenss: List of arrays with `ndim > 1`
    :returns: Local tensor of the sum, block-diagonal in the left and right
        virtual (bond) indices

    """
    shape = ltenss[0].shape
    # NOTE These are currently disabled due to real speed issues.
    #  if __debug__:
    #      for lt in ltenss[1:]:
    #          assert_array_equal(shape[1:-1], lt.shape[1:-1])

    # FIXME: Find out whether the following code does the same as
    # :func:`block_diag()` used by :func:`_local_sum_identity` and
    # which implementation is faster if so.
    newshape = (sum(lt.shape[0] for lt in ltenss), )
    newshape += shape[1:-1]
    newshape += (sum(lt.shape[-1] for lt in ltenss), )
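    # Common dtype every block can be safely cast to; max() over dtypes is
    # order-dependent for incomparable pairs such as int64 and float32
    res = np.zeros(newshape, dtype=np.result_type(*ltenss))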

    pos_l, pos_r = 0, 0
    for lt in ltenss:
        pos_l_new, pos_r_new = pos_l + lt.shape[0], pos_r + lt.shape[-1]
        res[pos_l:pos_l_new, ..., pos_r:pos_r_new] = lt
        pos_l, pos_r = pos_l_new, pos_r_new
    return res