test_experience.py source file

python

Project: reinforceflow  Author: dbobrenko
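import numpy as np
import numpy.testing as npt

# ProportionalReplay comes from the reinforceflow project; its import line is not shown in this excerpt.


def test_prop_replay_distribution():
    """With alpha=1, each transition should be sampled in proportion to its priority."""
    priors = [20000.0, 30000.0, 1000.0, 49000.0, 0.0]
    cap = 256
    batch_size = 32
    sample_amount = 2000
    replay = ProportionalReplay(capacity=cap, min_size=cap, batch_size=batch_size, alpha=1, beta=1)
    # Expected sampling probability of each transition: priority / total priority.
    s = int(np.sum(priors))
    expected_priors = np.asarray(priors) / s
    received_priors = [0] * len(priors)
    # Store one transition per priority; the observation value doubles as its index.
    for o, p in enumerate(priors):
        replay.add(obs=o, action=0, reward=0, obs_next=0, term=False, priority=p)
    # Draw many batches and count how often each transition is returned.
    for _ in range(sample_amount):
        obs, a, r, obs_next, terms, idxs, importance = replay.sample()
        for o in obs:
            received_priors[o] += 1
    # Normalize the hit counts into empirical sampling frequencies.
    received_priors = np.asarray(received_priors) / (sample_amount * batch_size)
    npt.assert_almost_equal(expected_priors, received_priors, decimal=2)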
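
The test above relies on one property: with alpha=1, a transition's sampling frequency should approach its priority divided by the sum of all priorities. The sketch below reproduces that property with NumPy alone, so it can be run without reinforceflow. It is an illustration under stated assumptions, not how ProportionalReplay is implemented; the helper name `sample_proportional` and the fixed seed are hypothetical choices made for this example.

```python
import numpy as np


def sample_proportional(priorities, n_samples, alpha=1.0, seed=0):
    """Draw indices with probability proportional to priority**alpha.

    Hypothetical helper for illustration only; not part of reinforceflow.
    """
    rng = np.random.default_rng(seed)
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = scaled / scaled.sum()
    indices = rng.choice(len(probs), size=n_samples, p=probs)
    return indices, probs


priors = [20000.0, 30000.0, 1000.0, 49000.0, 0.0]
# 2000 batches of 32 samples, matching the sizes used in the test.
indices, expected = sample_proportional(priors, n_samples=2000 * 32)
# Empirical frequency of each index across all draws.
received = np.bincount(indices, minlength=len(priors)) / len(indices)
np.testing.assert_almost_equal(expected, received, decimal=2)
```

A zero priority gives a zero sampling probability, so the last transition is never drawn here; the original test checks the same thing through its final entry in `priors`.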