policy_graphs.py source code

python

Project: gymmeforce    Author: lgvaz
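import tensorflow as tf  # TensorFlow 1.x (relies on tf.contrib and tf.layers)
from tensorflow.contrib.layers import variance_scaling_initializer


def dense_policy_graph(inputs,
                       env_config,
                       activation_fn=tf.nn.tanh,
                       scope='policy_graph',
                       reuse=None,
                       trainable=True):
    """Build a two-hidden-layer MLP policy.

    Returns (mean, logstd) when env_config['action_space'] is 'continuous'
    and logits when it is 'discrete'.
    """
    with tf.variable_scope(scope, reuse=reuse):
        # Flatten observations to shape (batch, features).
        net = tf.contrib.layers.flatten(inputs)
        # Two fully connected hidden layers of 64 units each.
        net = tf.layers.dense(
            inputs=net,
            units=64,
            activation=activation_fn,
            kernel_initializer=variance_scaling_initializer(factor=1),
            trainable=trainable)
        net = tf.layers.dense(
            inputs=net,
            units=64,
            activation=activation_fn,
            kernel_initializer=variance_scaling_initializer(factor=1),
            trainable=trainable)

        if env_config['action_space'] == 'continuous':
            # Gaussian head: a linear (no activation) layer gives the
            # observation-dependent mean of each action dimension.
            mean = tf.layers.dense(
                inputs=net,
                units=env_config['num_actions'],
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                name='mean',
                trainable=trainable)
            # The log standard deviation is a free variable of shape
            # (1, num_actions), independent of the observation and
            # initialized to 0 (i.e. std = 1).
            logstd = tf.get_variable(
                'logstd', (1, env_config['num_actions']),
                tf.float32,
                initializer=tf.zeros_initializer(),
                trainable=trainable)
            return mean, logstd

        if env_config['action_space'] == 'discrete':
            # Categorical head: one unnormalized log-probability per action.
            logits = tf.layers.dense(
                inputs=net,
                units=env_config['num_actions'],
                name='logits',
                trainable=trainable)
            return logits

        # Fail loudly instead of silently returning None.
        raise ValueError(
            "env_config['action_space'] must be 'continuous' or 'discrete', "
            "got {!r}".format(env_config['action_space']))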
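Both hidden layers use `variance_scaling_initializer(factor=1)`. As we understand the TF 1.x `tf.contrib.layers` defaults (`mode='FAN_IN'`, `uniform=False`), this initializer draws kernel weights from a normal distribution truncated at 2 standard deviations. The stddev is `sqrt(1.3 * factor / fan_in)`, where the 1.3 offsets the variance lost to truncation, so the weights end up with a standard deviation of roughly `sqrt(factor / fan_in)`. The pure-Python sketch below imitates that behaviour for illustration only. It is not the TensorFlow implementation, and the helper names (`variance_scaling_stddev`, `truncated_normal`, `init_dense_kernel`) are hypothetical.

```python
import math
import random
from typing import List


def variance_scaling_stddev(fan_in: int, factor: float = 1.0) -> float:
    # Assumed TF 1.x contrib rule for mode='FAN_IN', uniform=False:
    # inflate the variance by 1.3 to offset the truncation below.
    return math.sqrt(1.3 * factor / fan_in)


def truncated_normal(stddev: float, n: int, rng: random.Random) -> List[float]:
    # Redraw any sample that falls more than 2 stddev from zero.
    samples: List[float] = []
    while len(samples) < n:
        x = rng.gauss(0.0, stddev)
        if abs(x) <= 2.0 * stddev:
            samples.append(x)
    return samples


def init_dense_kernel(fan_in: int, units: int, factor: float,
                      rng: random.Random) -> List[List[float]]:
    # Hypothetical helper: a (fan_in, units) kernel, the layout a dense layer uses.
    stddev = variance_scaling_stddev(fan_in, factor)
    flat = truncated_normal(stddev, fan_in * units, rng)
    return [flat[i * units:(i + 1) * units] for i in range(fan_in)]


if __name__ == "__main__":
    # Second hidden layer of dense_policy_graph: 64 inputs -> 64 units.
    kernel = init_dense_kernel(fan_in=64, units=64, factor=1.0, rng=random.Random(0))
    weights = [w for row in kernel for w in row]
    rms = math.sqrt(sum(w * w for w in weights) / len(weights))
    print(round(rms, 3))  # root-mean-square weight, close to sqrt(1 / 64) = 0.125
```

For the first hidden layer, `fan_in` is the size of the flattened observation rather than 64.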
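In the continuous branch the graph returns an observation-dependent `mean` and a single `logstd` row of shape `(1, num_actions)`, which every observation in the batch shares. The source does not show how these outputs are consumed. This sketch assumes the common choice: a diagonal Gaussian policy with standard deviation `exp(logstd)`. It uses only the standard library, plain lists stand in for tensors, and the function names are hypothetical. It samples actions and computes their log-probability, applying the one `logstd` row to every batch row the way TensorFlow broadcasting would.

```python
import math
import random
from typing import List, Sequence

Row = Sequence[float]


def sample_actions(means: Sequence[Row], logstd: Row,
                   rng: random.Random) -> List[List[float]]:
    # a = mean + exp(logstd) * eps, with eps ~ N(0, 1) drawn per dimension.
    # The same logstd row is reused for every observation in the batch.
    return [[m + math.exp(ls) * rng.gauss(0.0, 1.0) for m, ls in zip(row, logstd)]
            for row in means]


def log_prob(actions: Sequence[Row], means: Sequence[Row],
             logstd: Row) -> List[float]:
    # Per-row log-density of a diagonal Gaussian: a sum over action dimensions.
    half_log_2pi = 0.5 * math.log(2.0 * math.pi)
    out = []
    for a_row, m_row in zip(actions, means):
        total = 0.0
        for a, m, ls in zip(a_row, m_row, logstd):
            z = (a - m) / math.exp(ls)
            total += -0.5 * z * z - ls - half_log_2pi
        out.append(total)
    return out


if __name__ == "__main__":
    means = [[0.0, 0.0], [1.0, -1.0]]  # batch of 2 observations, num_actions = 2
    logstd = [0.0, 0.0]                # the single row of the zero-initialized variable
    acts = sample_actions(means, logstd, random.Random(0))
    print(log_prob(acts, means, logstd))
```

Because `logstd` does not depend on the observation, exploration noise is controlled globally per action dimension, while the network only has to learn the mean.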
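In the discrete branch the graph returns unnormalized `logits`, one per action. This sketch assumes they parameterize a categorical policy, which is the usual reading, though the source does not show the consumer. A softmax then turns the logits into action probabilities. The code uses only the standard library, and its helper names are hypothetical. It computes a numerically stable log-softmax by subtracting the maximum logit first, then samples an action index from the resulting probabilities.

```python
import math
import random
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating so large logits cannot overflow.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sample_action(logits: Sequence[float], rng: random.Random) -> int:
    # Draw index i with probability softmax(logits)[i].
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [1.0, 2.0, 0.5]  # one observation, num_actions = 3
    action = sample_action(logits, random.Random(0))
    print(action, log_softmax(logits)[action])  # sampled index and its log-probability
```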