def test_sample_contexts_from_distribution():
    """Check that requested contexts follow ``context_distribution``, rescaled.

    ``uniform(5, 10)`` uses scipy's ``loc``/``scale`` form, so it covers
    [5, 15]. Relative to ``context_interval=(0, 20)`` that range is
    [0.25, 0.75]. The test therefore expects the first component of each
    requested context to be distributed as ``uniform(0.25, 0.5)``.
    """
    env = Catapult(segments=[(0, 0), (20, 0)], context_interval=(0, 20),
                   context_distribution=uniform(5, 10), random_state=0)
    env.init()

    n_samples = 1000
    contexts = np.empty(n_samples)
    for i in range(n_samples):
        # Only the first component of each requested context is checked.
        contexts[i] = env.request_context(None)[0]

    # Expected distribution of the rescaled contexts: uniform on [0.25, 0.75].
    expected_dist = uniform(0.25, 0.5)
    assert_true(np.all(0.25 <= contexts))
    assert_true(np.all(contexts <= 0.75))

    mean, var = expected_dist.stats("mv")
    assert_almost_equal(np.mean(contexts), mean, places=1)
    # The expected variance is 0.5**2 / 12 ~= 0.021, so places=1
    # (|err| < 0.05) would accept even a constant sampler. places=2
    # (|err| < 0.005) still leaves a wide margin: the sampling error of
    # the variance estimate for 1000 uniform draws is roughly 6e-4.
    assert_almost_equal(np.var(contexts), var, places=2)
评论列表
文章目录