def test_sum_stats_save_load(history: History):
    """Round-trip heterogeneous summary statistics through ``history``.

    Two particles for model 0 are appended as one population. Their summary
    statistics mix several value types:

    * a plain float;
    * NumPy arrays of shapes ``(10,)`` and ``(10, 2)``;
    * the object returned by ``example_df()``;
    * the R objects ``r["iris"]`` and ``r["mtcars"]``.

    The statistics are read back with ``history.get_sum_stats(0, 0)``. The
    test checks every value element-wise and asserts that both returned
    weights equal 0.5.

    Parameters
    ----------
    history:
        The ``History`` instance under test (presumably provided by a pytest
        fixture -- confirm against the fixture definition).
    """
    # Local import: ``sp.random`` is only SciPy's deprecated alias of
    # ``numpy.random``, so use NumPy directly (same global RNG).
    import numpy as np

    # ``pandas2ri.ri2py`` was renamed to ``rpy2py`` in rpy2 3.0. Prefer the
    # old name when it exists so behaviour on older rpy2 is unchanged.
    r_to_pandas = getattr(pandas2ri, "ri2py", None) or pandas2ri.rpy2py

    arr = np.random.rand(10)
    arr2 = np.random.rand(10, 2)
    particle_population = [
        ValidParticle(0,
                      Parameter({"a": 23, "b": 12}),
                      .2,
                      [.1],
                      [{"ss1": .1, "ss2": arr2,
                        "ss3": example_df(),
                        "rdf0": r["iris"]}]),
        ValidParticle(0,
                      Parameter({"a": 23, "b": 12}),
                      .2,
                      [.1],
                      [{"ss12": .11, "ss22": arr,
                        "ss33": example_df(),
                        "rdf": r["mtcars"]}]),
    ]
    # NOTE(review): positional args presumably are
    # (t, epsilon, particles, nr_simulations, model_names) -- confirm
    # against History.append_population.
    history.append_population(0, 42, particle_population, 2, ["m1", "m2"])

    weights, sum_stats = history.get_sum_stats(0, 0)
    # Both particles carry the same value (.2) in the third slot; presumably
    # weights are normalized on storage, hence 0.5 each.
    assert (weights == 0.5).all()

    # First particle: float, 2-D array, example DataFrame, R data frame.
    assert sum_stats[0]["ss1"] == .1
    assert (sum_stats[0]["ss2"] == arr2).all()
    assert (sum_stats[0]["ss3"] == example_df()).all().all()
    assert (sum_stats[0]["rdf0"] == r_to_pandas(r["iris"])).all().all()

    # Second particle: same kinds of values under different keys.
    assert sum_stats[1]["ss12"] == .11
    assert (sum_stats[1]["ss22"] == arr).all()
    assert (sum_stats[1]["ss33"] == example_df()).all().all()
    assert (sum_stats[1]["rdf"] == r_to_pandas(r["mtcars"])).all().all()
评论列表
文章目录