ddpg_network.py source code

python

Project: SRLF  Author: Fritz449
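# Method of the network class in ddpg_network.py (TensorFlow 1.x graph API).
# Assumes `import tensorflow as tf` and the project's own `denselayer` helper,
# which (judging from its use here) returns (output_tensor, list_of_variables).
def create_critic(self, name, state_input, action_input, reuse=False):
    """Build the critic Q(s, a); return (value, weights, weight_phs)."""
    hidden = state_input
    weights = []
    with tf.variable_scope(name, reuse=reuse):
        for index, n_hidden in enumerate(self.n_hiddens):
            # The action is concatenated before the second hidden layer rather
            # than at the input, as in the original DDPG paper's critic.
            if index == 1:
                hidden = tf.concat([hidden, action_input], axis=1)
            hidden, layer_weights = denselayer("hidden_critic_{}".format(index), hidden, n_hidden,
                                               self.nonlinearity, tf.truncated_normal_initializer())
            weights += layer_weights

        # Single-unit output layer with small uniform weights, so initial Q-values stay near zero.
        value, layer_weights = denselayer("value", hidden, 1,
                                          w_initializer=tf.random_uniform_initializer(-3e-3, 3e-3))
        value = tf.reshape(value, [-1])  # [batch, 1] -> [batch]
        weights += layer_weights
        # One placeholder per variable, shaped like it (e.g. for feeding new values to assign ops).
        weight_phs = [tf.placeholder(tf.float32, shape=w.get_shape()) for w in weights]
    return value, weights, weight_phs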
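The critic above passes the state through the first hidden layer, concatenates the action only before the second, and ends in one output per sample. The pure-Python sketch below mirrors that layout without TensorFlow so the data flow can be traced by hand. It is an illustration, not the SRLF code: the helper names (`truncated_normal`, `dense`, `build_critic`, `critic_forward`), zero bias initialization, a ReLU hidden activation and a linear output layer are all assumptions, because `denselayer` and `self.nonlinearity` are not shown in the snippet.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]      # shape [n_out][n_in]
Layer = Tuple[Matrix, Vector]   # (weights, biases)


def truncated_normal(rng: random.Random, stddev: float = 1.0) -> float:
    """Draw from N(0, stddev^2), re-drawing values more than 2 stddev from the mean."""
    while True:
        x = rng.gauss(0.0, stddev)
        if abs(x) <= 2.0 * stddev:
            return x


def relu(v: float) -> float:
    return max(0.0, v)


def dense(x: Sequence[float], layer: Layer, act: Callable[[float], float]) -> Vector:
    """act(W x + b) for a single example."""
    w, b = layer
    return [act(sum(wi * xi for wi, xi in zip(row, x)) + bi) for row, bi in zip(w, b)]


def build_critic(state_dim: int, action_dim: int, n_hiddens: Sequence[int], seed: int = 0) -> List[Layer]:
    """Hypothetical helper: create weights for the critic layout described above."""
    rng = random.Random(seed)
    layers: List[Layer] = []
    n_in = state_dim
    for index, n_hidden in enumerate(n_hiddens):
        if index == 1:
            n_in += action_dim  # the action joins the input of the second hidden layer
        w = [[truncated_normal(rng) for _ in range(n_in)] for _ in range(n_hidden)]
        layers.append((w, [0.0] * n_hidden))  # zero biases: an assumption
        n_in = n_hidden
    # Output layer: one unit, weights drawn uniformly from [-3e-3, 3e-3].
    w_out = [[rng.uniform(-3e-3, 3e-3) for _ in range(n_in)]]
    layers.append((w_out, [0.0]))
    return layers


def critic_forward(layers: List[Layer], states, actions, act=relu) -> Vector:
    """Hypothetical helper: one Q-value per (state, action) pair, i.e. a flat [batch] vector."""
    values = []
    for s, a in zip(states, actions):
        h = list(s)
        for index, layer in enumerate(layers[:-1]):
            if index == 1:
                h = h + list(a)  # same role as tf.concat([hidden, action_input], axis=1)
            h = dense(h, layer, act)
        values.append(dense(h, layers[-1], lambda v: v)[0])  # linear output, flattened
    return values


layers = build_critic(state_dim=3, action_dim=2, n_hiddens=[4, 4])
q = critic_forward(layers, [[0.1, 0.2, 0.3], [1.0, 0.0, -1.0]], [[0.5, -0.5], [0.0, 1.0]])
print(len(q))  # 2: one Q-value per (state, action) pair
```

As in the original method, if `n_hiddens` has only one entry the `index == 1` branch never runs, so the action is never concatenated and the output does not depend on it.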