a3c.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:reinforceflow 作者: dbobrenko 项目源码 文件源码
def __init__(self, input_space, output_space, layer_sizes=(512, 512, 512), trainable=True):
        """Build a fully connected actor-critic (A3C) graph.

        Args:
            input_space: Observation space. Its ``shape`` sets the input
                placeholder shape, with a leading batch dimension of ``None``.
            output_space: Action space, forwarded to ``make_a3c_header``.
            layer_sizes: Unit count for each hidden ReLU layer, in order.
                Layers are scoped ``fc0``, ``fc1``, ...
            trainable: Passed to every ``fully_connected`` layer created here
                and to ``make_a3c_header``.

        Raises:
            ValueError: If either ``input_space`` or ``output_space`` is a
                ``Tuple`` space.

        Side effects:
            Sets ``self._input_ph``, ``self.end_points`` (hidden layers,
            ``'out_value'`` and whatever ``make_a3c_header`` returns) and
            ``self.output_policy``.
        """
        if any(isinstance(space, Tuple) for space in (input_space, output_space)):
            raise ValueError('For tuple action and observation spaces '
                             'consider implementing custom network architecture.')

        placeholder_shape = [None] + list(input_space.shape)
        self._input_ph = tf.placeholder('float32', shape=placeholder_shape, name='inputs')

        points = {}
        # Flatten observations so the fully connected stack sees [batch, features].
        hidden = layers.flatten(self._input_ph)
        for idx, num_units in enumerate(layer_sizes):
            scope = 'fc%d' % idx
            hidden = layers.fully_connected(hidden,
                                            num_outputs=num_units,
                                            activation_fn=tf.nn.relu,
                                            trainable=trainable,
                                            scope=scope)
            points[scope] = hidden

        # Linear single-unit value head, with normally distributed initial weights and biases.
        normal_init = tf.random_normal_initializer
        value = layers.fully_connected(hidden,
                                       num_outputs=1,
                                       activation_fn=None,
                                       weights_initializer=normal_init(0.0, 0.1),
                                       biases_initializer=normal_init(0.05, 0.1),
                                       scope='out_value',
                                       trainable=trainable)
        # NOTE(review): squeeze without ``axis`` drops every size-1 dimension.
        # For a batch of one the value output therefore becomes a scalar
        # rather than shape [1]. Consider ``axis=1`` once callers are checked
        # not to rely on the scalar form.
        points['out_value'] = tf.squeeze(value)

        points.update(make_a3c_header(hidden, input_space, output_space, trainable))
        self.end_points = points
        # ``self.output`` is not defined in this method. Presumably it is a
        # property elsewhere in the class that reads ``self.end_points`` --
        # TODO confirm.
        self.output_policy = self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号