mparray.py source code


Project: mpnum  Author: dseuss
import collections.abc

import numpy as np


def _local_dot(ltens_l, ltens_r, axes):
    """Computes the local tensors of a dot product `dot(l, r)`.

    Besides computing the normal dot product, this function rearranges the
    virtual legs in such a way that the result is a valid local tensor again.

    :param ltens_l: Array with `ndim > 1`
    :param ltens_r: Array with `ndim > 1`
    :param axes: Axes to compute dot product using the convention of
        :func:`numpy.tensordot()`. Note that these correspond to the true
        (and not just the physical) legs of the local tensors
    :returns: Correct local tensor representation

    """
    # The number of contracted legs must be the same on both sides
    # (collections.Sequence was removed in Python 3.10; use collections.abc)
    clegs_l = len(axes[0]) if isinstance(axes[0], collections.abc.Sequence) else 1
    clegs_r = len(axes[1]) if isinstance(axes[1], collections.abc.Sequence) else 1
    assert clegs_l == clegs_r, \
        "Number of contracted legs differs: {} != {}".format(clegs_l, clegs_r)
    res = np.tensordot(ltens_l, ltens_r, axes=axes)
    # Rearrange the virtual-dimension legs. First move the left virtual leg
    # of ltens_r (first free leg of ltens_r) next to that of ltens_l ...
    res = np.rollaxis(res, ltens_l.ndim - clegs_l, 1)
    # ... then move the right virtual leg of ltens_l to the second-to-last
    # position, directly before the right virtual leg of ltens_r
    res = np.rollaxis(res, ltens_l.ndim - clegs_l,
                      ltens_l.ndim + ltens_r.ndim - clegs_l - clegs_r - 1)
    # Merge each pair of virtual legs into a single virtual leg
    return res.reshape((ltens_l.shape[0] * ltens_r.shape[0], ) +
                       res.shape[2:-2] +
                       (ltens_l.shape[-1] * ltens_r.shape[-1],))
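To make the leg bookkeeping concrete, here is a minimal sketch of one common case: applying an MPO local tensor to an MPS local tensor. It uses `numpy.einsum` instead of `tensordot` plus two `rollaxis` calls. The helper name `mpo_times_mps_local` is hypothetical (not part of mpnum), and the leg order (left virtual, physical legs, right virtual) is assumed from the way `_local_dot` reshapes `shape[0]` and `shape[-1]`. Under these assumptions it should produce the same array as `_local_dot(mpo, mps, axes=([2], [1]))`.

```python
import numpy as np


def mpo_times_mps_local(ltens_mpo, ltens_mps):
    """Hypothetical helper: local tensor of an MPO applied to an MPS.

    Assumed leg order for this sketch:
      ltens_mpo: (left virtual, output phys, input phys, right virtual)
      ltens_mps: (left virtual, phys, right virtual)
    The MPO input leg (axis 2) is contracted with the MPS physical leg
    (axis 1), i.e. axes=([2], [1]) in numpy.tensordot terms.
    """
    a, i, _, b = ltens_mpo.shape
    c, _, d = ltens_mps.shape
    # Output 'acibd': both left virtual legs first, the physical leg in the
    # middle, both right virtual legs last, the order _local_dot reaches
    # after its rollaxis calls.
    res = np.einsum('aijb,cjd->acibd', ltens_mpo, ltens_mps)
    # Merge (a, c) and (b, d); the MPO index varies slowest, matching the
    # row-major reshape in _local_dot.
    return res.reshape(a * c, i, b * d)


rng = np.random.default_rng(0)
mpo = rng.standard_normal((2, 3, 3, 4))  # virtual dims 2 and 4, phys dim 3
mps = rng.standard_normal((5, 3, 6))     # virtual dims 5 and 6, phys dim 3
out = mpo_times_mps_local(mpo, mps)
print(out.shape)  # (10, 3, 24)
```

The subscript string spells out the final leg order that the two `rollaxis` calls reach implicitly. The merged virtual dimensions are the products of the input ones (2·5 and 4·6 here).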