policygradient.py source code

python

Project: Safe-RL-Benchmark    Author: befelix

The method below estimates the policy gradient by finite differences: it perturbs each policy parameter by +var and -var, rolls out the perturbed policies, and recovers the gradient from the resulting performance differences with a least-squares fit.
# imports assumed from the surrounding module (not shown on this page);
# `solve` is taken here from numpy.linalg
import numpy as np
from numpy.linalg import solve


def _estimate_gradient(self, policy):
    env = self.environment
    var = self.var
    # store current policy parameter so it can be restored afterwards
    parameter = policy.parameters
    par_dim = policy.parameter_space.dimension

    # using forward differences: reference rollout of the unperturbed
    # policy, scored by its average reward per step
    trace = env.rollout(policy)
    j_ref = sum([x[2] for x in trace]) / len(trace)

    # one +var and one -var perturbation per parameter dimension
    dj = np.zeros(2 * par_dim)
    dv = np.append(np.eye(par_dim), -np.eye(par_dim), axis=0)
    dv *= var

    for n in range(2 * par_dim):
        variation = dv[n]

        policy.parameters = parameter + variation
        trace_n = env.rollout(policy)

        # average reward per step of the perturbed policy
        jn = sum([x[2] for x in trace_n]) / len(trace_n)

        # performance difference relative to the reference rollout
        dj[n] = jn - j_ref

    # least-squares gradient estimate: solve (dv^T dv) g = dv^T dj
    grad = solve(dv.T.dot(dv), dv.T.dot(dj))

    # reset current policy parameter
    policy.parameters = parameter

    return grad
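
Usage sketch. The method above only assumes that the environment exposes rollout(policy) returning (state, action, reward) tuples, and that the policy exposes parameters and parameter_space.dimension; the gradient is then recovered as g = (dv^T dv)^{-1} dv^T dj. The classes below (_ParameterSpace, _LinearPolicy, _QuadraticEnv, _FDEstimator) are hypothetical stand-ins for that interface, written here only for illustration; they are not part of Safe-RL-Benchmark.

import numpy as np


class _ParameterSpace:
    # minimal stand-in exposing only the dimension the method reads
    def __init__(self, dimension):
        self.dimension = dimension


class _LinearPolicy:
    # hypothetical deterministic policy: action = parameters . state
    def __init__(self, dim):
        self.parameter_space = _ParameterSpace(dim)
        self.parameters = np.zeros(dim)

    def map(self, state):
        return self.parameters.dot(state)


class _QuadraticEnv:
    # hypothetical environment: reward = -(action - 1)^2 on a fixed state
    def __init__(self, dim, horizon=10):
        self.state = np.ones(dim)
        self.horizon = horizon

    def rollout(self, policy):
        action = policy.map(self.state)
        reward = -(action - 1.0) ** 2
        # list of (state, action, reward) tuples; x[2] is the reward
        return [(self.state, action, reward)] * self.horizon


class _FDEstimator:
    # minimal host object providing the attributes the method reads
    def __init__(self, environment, var):
        self.environment = environment
        self.var = var

    _estimate_gradient = _estimate_gradient  # bind the function above


policy = _LinearPolicy(dim=3)
estimator = _FDEstimator(_QuadraticEnv(dim=3), var=0.5)

# a few gradient ascent steps on the average reward
for _ in range(50):
    grad = estimator._estimate_gradient(policy)
    policy.parameters = policy.parameters + 0.1 * grad

On the toy quadratic objective the estimate is exact, since the symmetric +var/-var perturbations cancel the curvature term, and the parameters converge toward the reward maximum.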