def test_contextual_optimizers_follow_standard_protocol():
    """Check that every contextual optimizer supports the standard protocol.

    For each ``(name, ContextualOptimizer)`` pair in
    ``ALL_CONTEXTUALOPTIMIZERS`` this test:

    1. default-constructs the optimizer;
    2. calls ``init`` with one parameter and one context dimension;
    3. runs one ``set_context`` / ``get_next_parameters`` /
       ``set_evaluation_feedback`` round;
    4. checks that the proposed parameters and the best policy's output
       are finite;
    5. passes the optimizer to ``assert_pickle``.

    Every assertion message includes the optimizer's name. Without it, a
    failure inside the loop would not tell which implementation broke.
    """
    # The problem dimensions are the same for every optimizer, so they are
    # defined once outside the loop.
    n_params = 1
    n_context_dims = 1
    for name, ContextualOptimizer in ALL_CONTEXTUALOPTIMIZERS:
        opt = ContextualOptimizer()
        opt.init(n_params, n_context_dims)

        # An optimizer may request a specific context. If it returns None,
        # fall back to the all-zeros context.
        context = opt.get_desired_context()
        if context is None:
            context = np.zeros(n_context_dims)
        opt.set_context(context)

        assert_false(opt.is_behavior_learning_done(),
                     "%s reports learning done before any feedback" % name)

        # The return value is ignored: the optimizer is expected to fill
        # ``params`` in place.
        params = np.empty(n_params)
        opt.get_next_parameters(params)
        assert_true(np.isfinite(params).all(),
                    "%s proposed non-finite parameters: %r" % (name, params))

        opt.set_evaluation_feedback(np.array([0.0]))

        policy = opt.best_policy()
        assert_true(np.isfinite(policy(context)).all(),
                    "%s best policy returned non-finite output" % name)

        assert_pickle(name, opt)
# (web-page residue) Comment list
# (web-page residue) Table of contents