def cnn(self, state, input_dims, num_actions):
    """Build an A3C-style conv net with a policy head and a value head.

    Args:
        state: input tensor, transposed below from NCHW to NHWC for conv2d.
            (assumes shape (batch, channels, height, width) — TODO confirm)
        input_dims: unused here; kept for interface compatibility.
        num_actions: size of the discrete action space (policy output width).

    Returns:
        (pi, V, params): row-normalized softmax policy tensor, value tensor
        of shape (batch,), and the list of trainable variables created.
    """
    w = {}
    initializer = tf.truncated_normal_initializer(0, 0.02)
    activation_fn = tf.nn.relu
    # conv2d is configured for NHWC, so move the channel axis last.
    state = tf.transpose(state, perm=[0, 2, 3, 1])
    l1, w['l1_w'], w['l1_b'] = conv2d(state,
        32, [8, 8], [4, 4], initializer, activation_fn, 'NHWC', name='l1')
    l2, w['l2_w'], w['l2_b'] = conv2d(l1,
        64, [4, 4], [2, 2], initializer, activation_fn, 'NHWC', name='l2')
    shape = l2.get_shape().as_list()
    # Flatten all non-batch dims. Computed with a plain loop because
    # `reduce` is not a builtin in Python 3 — the original line would
    # raise NameError unless functools.reduce was imported elsewhere.
    flat_dim = 1
    for d in shape[1:]:
        flat_dim *= d
    l2_flat = tf.reshape(l2, [-1, flat_dim])
    # Shared fully-connected trunk feeding both heads.
    l3, w['l3_w'], w['l3_b'] = linear(l2_flat, 256, activation_fn=activation_fn, name='value_hid')
    value, w['val_w_out'], w['val_w_b'] = linear(l3, 1, name='value_out')
    V = tf.reshape(value, [-1])
    pi_, w['pi_w_out'], w['pi_w_b'] = \
        linear(l3, num_actions, activation_fn=tf.nn.softmax, name='pi_out')
    # Renormalize the softmax output so each row sums to exactly 1
    # (guards against floating-point drift before sampling).
    sums = tf.tile(tf.expand_dims(tf.reduce_sum(pi_, 1), 1), [1, num_actions])
    pi = pi_ / sums
    # A3C reference sizes: l1=(16,[8,8],[4,4],ReLU), l2=(32,[4,4],[2,2],ReLU),
    # l3=(256, FC, ReLU), V=(1, FC, Linear), pi=(#act, FC, Softmax)
    return pi, V, list(w.values())
# Adapted from github.com/devsisters/DQN-tensorflow/
# (removed page-scrape artifacts: "评论列表" / "文章目录" — site-navigation
#  text meaning "comment list" / "table of contents", not code)