def test_partialdot(nr_sites, local_dim, rank, rgen, dtype):
    """Compare ``mp.partialdot`` with an equivalent dense matrix product.

    Two random MPOs are drawn from ``rgen``: one spanning all ``nr_sites``
    sites and one spanning ``nr_sites // 2`` sites. The smaller one is
    multiplied with the larger one (larger · smaller and smaller · larger),
    aligned to start at site ``min(2, nr_sites // 2)``. The dense reference
    embeds the smaller operator between identity matrices via Kronecker
    products. Both results must match the reference and have dtype ``dtype``.
    """
    # A partial operator needs a chain of at least two sites; for shorter
    # chains there is nothing to check and the test simply returns.
    if nr_sites < 2:
        return

    n_part = nr_sites // 2
    offset = min(2, n_part)
    n_after = nr_sites - n_part - offset
    site_shape = (local_dim, local_dim)
    full_shape = (local_dim**nr_sites,) * 2

    # Both draws use the same rgen, so this order determines which random
    # operators are produced.
    big = factory.random_mpa(nr_sites, site_shape, rank,
                             randstate=rgen, dtype=dtype)
    big_dense = big.to_array_global().reshape(full_shape)
    small = factory.random_mpa(n_part, site_shape, rank,
                               randstate=rgen, dtype=dtype)
    small_dense = small.to_array_global().reshape((local_dim**n_part,) * 2)

    # Identity on the `offset` leading sites, then the partial operator,
    # then identity on the `n_after` trailing sites.
    embedded = np.kron(np.kron(np.eye(local_dim**offset), small_dense),
                       np.eye(local_dim**n_after))

    # Dense references: big · small and small · big.
    expected = [big_dense.dot(embedded), embedded.dot(big_dense)]

    actual_mpas = [mp.partialdot(left, right, start_at=offset)
                   for left, right in ((big, small), (small, big))]
    actual = [mpa.to_array_global().reshape(full_shape)
              for mpa in actual_mpas]

    for want, got in zip(expected, actual):
        assert_array_almost_equal(want, got)
    for got in actual:
        assert got.dtype == dtype
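# Stray web-page navigation labels ("Comment list", "Table of contents")
# were captured here with this snippet; they are not part of the code.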