def __init__(self, ob_space, ac_space, layers=(256,), **kwargs):
    """Build a feed-forward policy/value graph on top of an input placeholder.

    Args:
        ob_space: Sequence of observation dimensions (batch axis excluded).
            Its length must be 3 (treated as pixel input) or 1 (treated as
            a plain feature vector).
        ac_space: Output size of the "action" layer; also forwarded to
            ``categorical_sample``. Presumably the number of discrete
            actions -- confirm against callers.
        layers: Iterable of hidden-layer sizes. One ELU-activated linear
            layer is created per entry, scoped "l1", "l2", ... in order.
            The default is an immutable tuple so it cannot be shared and
            mutated across calls.
        **kwargs: Accepted and ignored by this constructor.

    Raises:
        TypeError: If ``len(ob_space)`` is neither 1 nor 3.
    """
    # Input placeholder; the leading None leaves the first axis unspecified.
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

    rank = len(ob_space)
    if rank == 3:  # pixel input
        # Four stacked conv layers scoped "c1".."c4".
        # NOTE(review): arguments look like 32 filters, 3x3 kernel, stride 2
        # -- confirm against the conv2d helper's signature.
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "c{}".format(i + 1), [3, 3], [2, 2]))
    elif rank != 1:
        # Rank-1 (plain feature) input needs no preprocessing and falls
        # through to the dense stack below; anything else is rejected.
        raise TypeError("observation space must have rank 1 or 3, got %d" % rank)

    x = flatten(x)

    # Hidden dense stack, one layer per requested size.
    for i, layer in enumerate(layers):
        x = tf.nn.elu(linear(x, layer, "l{}".format(i + 1), tf.contrib.layers.xavier_initializer()))

    # Two heads share the same features: action logits and a scalar value.
    self.logits = linear(x, ac_space, "action", tf.contrib.layers.xavier_initializer())
    # Reshape the 1-unit value output to a 1-D tensor.
    self.vf = tf.reshape(linear(x, 1, "value", tf.contrib.layers.xavier_initializer()), [-1])
    # Keep only row 0 of whatever categorical_sample returns.
    self.sample = categorical_sample(self.logits, ac_space)[0, :]
    # Trainable variables that live under the currently active variable scope.
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    # Empty list: presumably this policy has no recurrent state inputs --
    # confirm against code that consumes state_in.
    self.state_in = []
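# NOTE: stray web-page navigation labels ("Comment list" / "Article contents")
# left over from the page this snippet was copied from; not part of the program.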