a2c.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:nn_q_learning_tensorflow 作者: EndingCredits 项目源码 文件源码
def cnn(self, state, input_dims, num_actions):
        """Build a two-layer conv network with a shared hidden layer and
        separate policy and value output heads.

        Args:
            state: Input tensor. Axis 1 is moved to the last position before
                the convolutions, so this presumably arrives as NCHW
                (batch, channels, height, width) -- TODO confirm with caller.
            input_dims: Not used by this method; kept for interface
                compatibility with callers.
            num_actions: Number of units in the policy output layer.

        Returns:
            A tuple ``(pi, V, weights)`` where ``pi`` is the policy output
            renormalised so each row sums to one, ``V`` is the value output
            reshaped to rank 1, and ``weights`` is a list of every
            weight/bias object returned by the ``conv2d``/``linear`` helpers.
        """
        w = {}
        initializer = tf.truncated_normal_initializer(0, 0.02)
        activation_fn = tf.nn.relu

        # Move axis 1 to the end so the conv helpers can be run as 'NHWC'.
        state = tf.transpose(state, perm=[0, 2, 3, 1])

        # NOTE(review): positional args look like (output_dim, kernel, stride)
        # following the DQN-tensorflow ops helpers -- confirm against conv2d.
        l1, w['l1_w'], w['l1_b'] = conv2d(state,
          32, [8, 8], [4, 4], initializer, activation_fn, 'NHWC', name='l1')
        l2, w['l2_w'], w['l2_b'] = conv2d(l1,
          64, [4, 4], [2, 2], initializer, activation_fn, 'NHWC', name='l2')

        # Flatten every non-batch axis of l2 into a single feature axis.
        # The product is computed with a plain loop rather than `reduce`,
        # which is not a builtin on Python 3 and would need an import.
        shape = l2.get_shape().as_list()
        flat_dim = 1
        for dim in shape[1:]:
            flat_dim *= dim
        l2_flat = tf.reshape(l2, [-1, flat_dim])

        # Hidden layer feeding both heads. The scope name 'value_hid' is kept
        # as-is so existing variable names/checkpoints are unaffected.
        l3, w['l3_w'], w['l3_b'] = linear(l2_flat, 256, activation_fn=activation_fn, name='value_hid')

        # Value head: a single linear unit, reshaped to a rank-1 tensor.
        value, w['val_w_out'], w['val_w_b'] = linear(l3, 1, name='value_out')
        V = tf.reshape(value, [-1])

        # Policy head with softmax passed as the activation function.
        pi_, w['pi_w_out'], w['pi_w_b'] = \
            linear(l3, num_actions, activation_fn=tf.nn.softmax, name='pi_out')

        # Explicitly renormalise each row by its sum. The [batch, 1] row sums
        # broadcast across the action axis, so no tf.tile is required.
        pi = pi_ / tf.expand_dims(tf.reduce_sum(pi_, 1), 1)

        # A3C is l1 = (16, [8,8], [4,4], ReLu), l2 = (32, [4,4], [2,2], ReLu),
        # l3 = (256, Conn, ReLu), V = (1, Conn, Lin), pi = (#act, Conn, Softmax)
        return pi, V, list(w.values())




# Adapted from github.com/devsisters/DQN-tensorflow/
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号