mparray_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:mpnum 作者: dseuss 项目源码 文件源码
def test_partialdot(nr_sites, local_dim, rank, rgen, dtype):
    """Compare ``mp.partialdot`` with the equivalent dense matrix products.

    A random operator on the whole chain and a random operator on a
    contiguous part of it are multiplied in both orders with
    ``mp.partialdot``.  The dense reference is ``np.dot`` of the full
    matrix with the partial matrix, the latter padded by identities on
    the sites it does not act on.  The dense results must also have the
    requested ``dtype``.
    """
    # Applying an operator to a part of a chain needs at least two sites;
    # shorter chains are silently accepted.
    if nr_sites < 2:
        return

    half = nr_sites // 2
    part_sites = half
    start_at = min(2, half)
    sites_after = nr_sites - part_sites - start_at
    site_shape = (local_dim, local_dim)

    def as_matrix(mpa, sites):
        # Dense square matrix of an operator living on `sites` sites.
        dim = local_dim**sites
        return mpa.to_array_global().reshape((dim, dim))

    # Both draws are handed the same `rgen`; keep the full-chain operator
    # first and the partial operator second.
    mpo = factory.random_mpa(
        nr_sites, site_shape, rank, randstate=rgen, dtype=dtype)
    op = as_matrix(mpo, nr_sites)
    mpo_part = factory.random_mpa(
        part_sites, site_shape, rank, randstate=rgen, dtype=dtype)
    op_part = as_matrix(mpo_part, part_sites)

    # Embed the partial operator: identity (x) op_part (x) identity.
    eye_before = np.eye(local_dim**start_at)
    eye_after = np.eye(local_dim**sites_after)
    op_part_embedded = np.kron(np.kron(eye_before, op_part), eye_after)

    # Dense reference products in both orders.
    prod1 = np.dot(op, op_part_embedded)
    prod2 = np.dot(op_part_embedded, op)

    # The same products computed on the matrix-product representations.
    dense1 = as_matrix(
        mp.partialdot(mpo, mpo_part, start_at=start_at), nr_sites)
    dense2 = as_matrix(
        mp.partialdot(mpo_part, mpo, start_at=start_at), nr_sites)

    assert_array_almost_equal(prod1, dense1)
    assert_array_almost_equal(prod2, dense2)
    assert dense1.dtype == dtype
    assert dense2.dtype == dtype
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号