def test_sumup(nr_sites, local_dim, rank, rgen, dtype):
    """Check ``mp.sumup`` against a naive left-fold of ``MPArray.__add__``.

    Builds ``rank`` random MPAs of rank 3 (a single one when ``rank`` is
    NaN). For both an unweighted and a weighted sum, it verifies three
    things:

    * ``mp.sumup`` agrees elementwise with
      ``functools.reduce(mp.MPArray.__add__, ...)``;
    * no rank of the result exceeds the total rank of the summands;
    * the result's ``dtype`` is the requested ``dtype``.
    """
    summand_rank = 3
    # NaN marks parameter sets where ``rank`` is not meaningful; use a single
    # summand there. ``np.isnan`` (unlike ``rank is not np.nan``) also catches
    # NaN values that are not the ``np.nan`` singleton object.
    nr_summands = 1 if np.isnan(rank) else rank
    mpas = [factory.random_mpa(nr_sites, local_dim, summand_rank,
                               dtype=dtype, randstate=rgen)
            for _ in range(nr_summands)]

    # Ranks add up at most under summation, so bound them by the summands'
    # total rank. Deriving the bound from len(mpas) rather than ``rank``
    # keeps it finite when ``rank`` is NaN: ``r <= nan`` is always False and
    # would otherwise only pass when ``sum_mp.ranks`` is empty.
    max_rank = summand_rank * len(mpas)

    def check(sum_naive, sum_mp):
        # Shared assertions for the unweighted and weighted cases.
        assert_array_almost_equal(sum_naive.to_array(), sum_mp.to_array())
        assert all(r <= max_rank for r in sum_mp.ranks)
        assert sum_mp.dtype is dtype

    # Unweighted sum.
    check(ft.reduce(mp.MPArray.__add__, mpas), mp.sumup(mpas))

    # Weighted sum: scale each summand by a weight drawn with rgen.randn and
    # compare against summing the pre-scaled MPAs naively.
    weights = rgen.randn(len(mpas))
    summands = [w * mpa for w, mpa in zip(weights, mpas)]
    check(ft.reduce(mp.MPArray.__add__, summands),
          mp.sumup(mpas, weights=weights))
评论列表
文章目录