a3c.py source code

python

Project: Safe-RL-Benchmark  Author: befelix
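import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder, tf.train)


def __init__(self, policy, rate, train=True):
    """Linear value estimator V_est = X W, trained on summed squared error.

    policy: supplies dtype, the input placeholder X, layer sizes and init_weights.
    rate: RMSProp learning rate. train: if False, no optimizer ops are built.
    """
    self.rate = rate

    with tf.variable_scope('value_estimator'):
        # Input features, same shape and dtype as the policy's input.
        self.X = tf.placeholder(policy.dtype, shape=policy.X.shape, name='X')
        # Regression targets for the value estimate, one scalar per sample.
        self.V = tf.placeholder(policy.dtype, shape=[None], name='V')

        # Linear weights mapping the input features to a single value.
        self.W = policy.init_weights((policy.layers[0], 1))

        # tf.matmul yields shape [None, 1]; squeeze it to [None] so it lines up
        # with V (otherwise squared_difference broadcasts to [None, None]).
        self.V_est = tf.squeeze(tf.matmul(self.X, self.W), axis=1)

        self.losses = tf.squared_difference(self.V_est, self.V)
        self.loss = tf.reduce_sum(self.losses, name='loss')

        if train:
            self.opt = tf.train.RMSPropOptimizer(
                rate, decay=0.99, momentum=0.0, epsilon=1e-6)
            # Drop variables (e.g. the policy's own) that get no gradient here.
            self.grads_and_vars = [
                (g, v) for g, v in self.opt.compute_gradients(self.loss)
                if g is not None]
            self.update = self.opt.apply_gradients(self.grads_and_vars)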
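The graph above fits a linear value estimator, V_est = X W, by minimizing the summed squared error against the targets V with RMSProp (decay 0.99, no momentum, epsilon 1e-6). As a framework-free sketch of the same computation, the NumPy code below writes the gradient 2 X^T (X W - V) out by hand and applies a plain RMSProp step. The helper names (`value_loss`, `value_grad`, `rmsprop_step`, `fit_linear_value`), the small random initialization standing in for `policy.init_weights`, and starting the mean-square accumulator at ones are assumptions made for illustration; this is not meant to reproduce `tf.train.RMSPropOptimizer` bit for bit. It also shows the shape pitfall to watch for in this kind of loss: subtracting an [N] vector from an [N, 1] column broadcasts to an [N, N] matrix rather than N residuals.

```python
import numpy as np


def value_loss(X, W, V):
    """Summed squared error between linear estimates X @ W and targets V."""
    v_est = (X @ W).reshape(-1)  # flatten [N, 1] -> [N] so it lines up with V
    return float(np.sum((v_est - V) ** 2))


def value_grad(X, W, V):
    """Gradient of value_loss with respect to W, shape [d, 1]."""
    residual = (X @ W).reshape(-1) - V      # [N]
    return 2.0 * X.T @ residual[:, None]    # [d, N] @ [N, 1] -> [d, 1]


def rmsprop_step(W, ms, grad, lr, decay=0.99, eps=1e-6):
    """One RMSProp step without momentum; returns the updated (W, ms)."""
    ms = decay * ms + (1.0 - decay) * grad ** 2   # running mean of squared grads
    W = W - lr * grad / np.sqrt(ms + eps)         # per-coordinate scaled step
    return W, ms


def fit_linear_value(X, V, lr=0.01, steps=2000, seed=0):
    """Hypothetical helper: fit W by repeated RMSProp steps on value_loss."""
    rng = np.random.default_rng(seed)
    W = 0.01 * rng.standard_normal((X.shape[1], 1))  # stand-in for init_weights
    ms = np.ones_like(W)  # assumption: accumulator starts at ones
    for _ in range(steps):
        W, ms = rmsprop_step(W, ms, value_grad(X, W, V), lr)
    return W


if __name__ == "__main__":
    # Broadcasting pitfall: [N, 1] minus [N] gives an [N, N] matrix, not [N].
    col, vec = np.zeros((4, 1)), np.zeros(4)
    print((col - vec).shape)  # (4, 4)

    # Synthetic, noise-free data generated from known linear weights.
    rng = np.random.default_rng(1)
    X = rng.standard_normal((200, 3))
    true_w = np.array([[1.0], [-2.0], [0.5]])
    V = (X @ true_w).reshape(-1)

    W = fit_linear_value(X, V)
    print(np.round(W.ravel(), 3), value_loss(X, W, V))
```

With a fixed learning rate, RMSProp normalizes each coordinate's step by its recent gradient magnitude, so near the optimum the weights settle into a small oscillation on the order of `lr` rather than converging exactly; a decaying learning rate would tighten this.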