def test_prop_replay_update():
    """Check that sampling frequencies follow the priorities set by ``update``.

    Five transitions are stored, each with its slot index as the observation,
    then their priorities are overwritten in one ``update`` call.  Many batches
    are drawn and the fraction of draws per observation is compared, to two
    decimals, with the normalized priorities.

    NOTE(review): the replay is built with ``alpha=1`` and ``beta=1``;
    presumably ``alpha=1`` makes the sampling probability directly
    proportional to the priority -- confirm against ``ProportionalReplay``.
    """
    priors = np.array([20000.0, 30000.0, 1000.0, 49000.0, 0.0])
    cap = 2048
    batch_size = 32
    sample_amount = 2000
    replay = ProportionalReplay(
        capacity=cap,
        min_size=cap,
        batch_size=batch_size,
        alpha=1,
        beta=1,
    )

    # Normalized priorities: the sampling distribution the test expects.
    expected_priors = priors / np.sum(priors)

    # Store one transition per priority; the observation doubles as the slot
    # index so that sampled observations can be tallied directly.
    slots = list(range(len(priors)))
    for slot in slots:
        replay.add(obs=slot, action=0, reward=0, obs_next=0, term=False)
    replay.update(slots, priors)

    hit_counts = [0] * len(priors)
    for _ in range(sample_amount):
        # Only the observations matter here; the full 7-tuple unpack is kept
        # so a change in the shape of sample()'s return value still fails.
        obs, _, _, _, _, _, _ = replay.sample()
        for seen in obs:
            hit_counts[seen] += 1

    # Assumes every sample() call yields batch_size observations -- TODO confirm.
    received_priors = np.asarray(hit_counts) / (sample_amount * batch_size)

    # numpy's signature is (actual, desired): measured value first so that a
    # failure report labels the two arrays correctly.
    npt.assert_almost_equal(received_priors, expected_priors, decimal=2)
# Comment list
# Table of contents