test_storage.py 文件源码

python
阅读 51 收藏 0 点赞 0 评论 0

项目:pyabc 作者: neuralyzer 项目源码 文件源码
def test_sum_stats_save_load(history: History):
    """Round-trip heterogeneous summary statistics through ``history``.

    Two particles are appended as one population. Each carries a dict of
    summary statistics mixing a float, a NumPy array, a pandas DataFrame
    (``example_df()``) and an R data frame (``r["iris"]`` / ``r["mtcars"]``).
    The statistics read back with ``history.get_sum_stats(0, 0)`` must equal
    the originals, and every returned weight must be exactly 0.5.
    """
    # Use NumPy directly: ``sp.random`` was only SciPy's re-export of
    # ``numpy.random``, which is deprecated and absent from newer SciPy
    # releases. Imported locally so this test does not rely on a module-level
    # ``np`` name being present.
    import numpy as np

    # rpy2 < 3 exposes the R -> pandas converter as ``pandas2ri.ri2py``;
    # rpy2 >= 3 renamed it ``pandas2ri.rpy2py``. Prefer the old name so the
    # behaviour is identical wherever it still exists.
    r_to_pandas = getattr(pandas2ri, "ri2py", None) or pandas2ri.rpy2py

    arr = np.random.rand(10)
    arr2 = np.random.rand(10, 2)
    particle_population = [
        ValidParticle(0, Parameter({"a": 23, "b": 12}),
                      .2,
                      [.1],
                      [{"ss1": .1, "ss2": arr2,
                        "ss3": example_df(),
                        "rdf0": r["iris"]}]),
        ValidParticle(0,
                      Parameter({"a": 23, "b": 12}),
                      .2,
                      [.1],
                      [{"ss12": .11, "ss22": arr,
                        "ss33": example_df(),
                        "rdf": r["mtcars"]}])
    ]
    history.append_population(0, 42, particle_population, 2, ["m1", "m2"])
    weights, sum_stats = history.get_sum_stats(0, 0)

    # Both particles were constructed identically apart from their summary
    # statistics; the test expects each returned weight to be exactly 0.5.
    assert (weights == 0.5).all()

    # The assertions index sum_stats positionally, so they assume results come
    # back in insertion order. The DataFrame comparisons use element-wise ``==``
    # reduced with ``.all().all()``.
    assert sum_stats[0]["ss1"] == .1
    assert (sum_stats[0]["ss2"] == arr2).all()
    assert (sum_stats[0]["ss3"] == example_df()).all().all()
    assert (sum_stats[0]["rdf0"] == r_to_pandas(r["iris"])).all().all()
    assert sum_stats[1]["ss12"] == .11
    assert (sum_stats[1]["ss22"] == arr).all()
    assert (sum_stats[1]["ss33"] == example_df()).all().all()
    assert (sum_stats[1]["rdf"] == r_to_pandas(r["mtcars"])).all().all()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号