def perturb_model(args, model, random_seed, env):
    """
    Build an antithetic pair of perturbed copies of ``model``.

    Two fresh ``ES`` networks are created and loaded with ``model``'s
    state dict. For every entry yielded by ``es_params()``, Gaussian
    noise ``eps ~ N(0, 1)`` scaled by ``args.sigma`` is added in place to
    the first copy and the same noise is subtracted from the second.

    Args:
        args: Settings object; ``args.sigma`` (noise scale) and
            ``args.small_net`` (forwarded to ``ES``) are read.
        model: Source network. Only ``model.state_dict()`` is read here;
            presumably ``load_state_dict`` copies the values as
            ``torch.nn.Module`` does, leaving ``model`` itself
            unchanged -- confirm against the ``ES`` class.
        random_seed: Seed handed to ``np.random.seed``. NOTE: this
            reseeds NumPy's *global* generator as a side effect, so the
            same seed reproduces the same noise sequence later.
        env: Environment supplying ``observation_space.shape[0]`` and
            ``action_space`` for the ``ES`` constructor.

    Returns:
        ``[new_model, anti_model]``: the positively and the negatively
        perturbed copy, in that order.
    """
    def _copy_of_model():
        # Both copies are built with identical constructor arguments and
        # start from the same weights as ``model``.
        clone = ES(env.observation_space.shape[0],
                   env.action_space, args.small_net)
        clone.load_state_dict(model.state_dict())
        return clone

    new_model = _copy_of_model()
    anti_model = _copy_of_model()

    # Seed the global NumPy RNG (kept global on purpose: code elsewhere
    # may regenerate the identical noise from ``random_seed``).
    np.random.seed(random_seed)
    for (_, v), (_, anti_v) in zip(new_model.es_params(),
                                   anti_model.es_params()):
        eps = np.random.normal(0, 1, v.size())
        # Convert once; subtracting the tensor is bit-identical to adding
        # the tensor built from ``args.sigma * -eps``.
        noise = torch.from_numpy(args.sigma * eps).float()
        v += noise
        anti_v -= noise
    return [new_model, anti_model]
评论列表
文章目录