a3c.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:DeepRL 作者: arnomoonens 项目源码 文件源码
def __init__(self, state_shape, n_hidden, summary=True):
        """Build the critic's TensorFlow graph under the "critic" variable scope.

        The graph consists of a `states` input placeholder, a target
        placeholder `r`, one fully connected tanh hidden layer, a scalar
        value head and a sum-of-squared-errors loss between the value
        estimate and `r`.

        Args:
            state_shape: Dimensions of a single state, without the batch
                dimension. Any sequence (list, tuple, ...) is accepted.
            n_hidden: Number of units in the hidden layer.
            summary: Not used anywhere in this constructor.
                NOTE(review): presumably meant to toggle summary creation;
                confirm against the rest of the class.
        """
        super(CriticNetwork, self).__init__()
        self.state_shape = state_shape
        self.n_hidden = n_hidden

        with tf.variable_scope("critic"):
            # list(...) so that a tuple shape does not make the concatenation
            # with [None] raise a TypeError.
            batch_shape = [None] + list(self.state_shape)
            self.states = tf.placeholder(tf.float32, batch_shape, name="states")
            # Targets the value estimate is regressed towards, one per state.
            # NOTE(review): presumably discounted returns; confirm with caller.
            self.r = tf.placeholder(tf.float32, [None], name="r")

            L1 = tf.contrib.layers.fully_connected(
                inputs=self.states,
                num_outputs=self.n_hidden,
                activation_fn=tf.tanh,
                weights_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.02),
                biases_initializer=tf.zeros_initializer(),
                scope="L1")

            # `linear` and `normalized_columns_initializer` are defined
            # outside this block. The single-output result is flattened to
            # 1-D so that it lines up with the 1-D `r` placeholder.
            self.value = tf.reshape(linear(L1, 1, "value", normalized_columns_initializer(1.0)), [-1])

            # Summed (not averaged) squared error over the batch.
            self.loss = tf.reduce_sum(tf.square(self.value - self.r))
            self.summary_loss = self.loss
            # Trainable variables whose names fall under the current scope.
            self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号