policygradient.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Safe-RL-Benchmark 作者: befelix 项目源码 文件源码
def _estimate_gradient(self, policy):
        env = self.environment

        parameter = policy.parameters
        par_dim = policy.parameter_space.dimension

        dj = np.zeros((par_dim,))
        dv = np.eye(par_dim) * self.var / 2

        for n in range(par_dim):
            variation = dv[n]

            policy.parameters = parameter + variation
            trace_n = env.rollout(policy)

            policy.parameters = parameter - variation
            trace_n_ref = env.rollout(policy)

            jn = sum([x[2] for x in trace_n]) / len(trace_n)
            jn_ref = sum([x[2] for x in trace_n_ref]) / len(trace_n_ref)

            dj[n] = jn - jn_ref

        grad = solve(dv.T.dot(dv), dv.T.dot(dj))
        policy.parameters = parameter

        return grad
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号