train.py source code

python

Project: pytorch-es    Author: atgambardella
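import numpy as np
import torch

from model import ES  # ES policy network from elsewhere in pytorch-es (module name assumed)


def perturb_model(args, model, random_seed, env):
    """
    Creates two copies of the given model: one whose parameters are
    shifted by a Gaussian perturbation (sigma * eps) and one shifted by
    the negated perturbation (-sigma * eps). The original model is left
    unchanged. Returns [new_model, anti_model].
    """
    # Build two fresh networks with the same architecture and copy the weights in
    new_model = ES(env.observation_space.shape[0],
                   env.action_space, args.small_net)
    anti_model = ES(env.observation_space.shape[0],
                    env.action_space, args.small_net)
    new_model.load_state_dict(model.state_dict())
    anti_model.load_state_dict(model.state_dict())
    # Seeding makes the noise reproducible from random_seed alone
    np.random.seed(random_seed)
    # In-place updates on parameters must not be recorded by autograd
    with torch.no_grad():
        for (k, v), (anti_k, anti_v) in zip(new_model.es_params(),
                                            anti_model.es_params()):
            eps = np.random.normal(0, 1, v.size())
            noise = torch.from_numpy(args.sigma * eps).float()
            v += noise        # theta + sigma * eps
            anti_v -= noise   # theta - sigma * eps (mirrored sample)
    return [new_model, anti_model]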
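The function produces a mirrored (antithetic) pair, theta + sigma*eps and theta - sigma*eps, from a single seeded noise draw, so the noise can later be recomputed from the seed alone. The sketch below reproduces that idea on plain NumPy arrays so it runs without the project's ES class. It is a minimal illustration under assumptions: a dict of arrays stands in for `es_params()`, the helpers `perturb_params`, `regenerate_noise` and `antithetic_gradient` are hypothetical, and a local `np.random.RandomState` replaces the global generator to avoid side effects. The gradient helper shows the standard mirrored-sampling ES estimate as one typical way such pairs are consumed; the pytorch-es training loop may combine rewards differently.

```python
import numpy as np


def perturb_params(params, sigma, seed):
    """Hypothetical helper: return (theta + sigma*eps, theta - sigma*eps).

    `params` maps names to arrays (a stand-in for es_params()); eps ~ N(0, I)
    is drawn from a generator seeded with `seed`, one array per parameter.
    """
    rng = np.random.RandomState(seed)
    pos, neg = {}, {}
    for name, value in params.items():
        eps = rng.normal(0, 1, value.shape)
        pos[name] = value + sigma * eps
        neg[name] = value - sigma * eps
    return pos, neg


def regenerate_noise(params, seed):
    """Hypothetical helper: redraw the same eps from the seed, in the same order."""
    rng = np.random.RandomState(seed)
    return {name: rng.normal(0, 1, value.shape) for name, value in params.items()}


def antithetic_gradient(params, seeds, rewards_pos, rewards_neg, sigma):
    """Mirrored-sampling estimate: mean over pairs of (F+ - F-) * eps / (2*sigma)."""
    grad = {name: np.zeros_like(value) for name, value in params.items()}
    for seed, r_pos, r_neg in zip(seeds, rewards_pos, rewards_neg):
        eps = regenerate_noise(params, seed)
        for name in params:
            grad[name] += (r_pos - r_neg) * eps[name]
    n = len(seeds)
    return {name: g / (2 * sigma * n) for name, g in grad.items()}
```

For a usage check, take a linear reward F(theta) = w . theta: every pair gives F+ - F- = 2*sigma*(w . eps), so the estimate averages (w . eps) * eps and approaches w as the number of seeds grows. Only the seeds and scalar rewards need to be kept per pair, which is the point of seeding the noise.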