def build_networks(self):
    """Build the graph pieces that live in the ``"shared"`` variable scope.

    Attributes set on ``self``:

    * ``states`` -- float32 placeholder of shape ``[None] + obs_shape``, where
      ``obs_shape`` is ``self.envs[0].observation_space.shape``.
    * ``action_taken`` / ``advantage`` -- float32 placeholders created without
      a static shape.
    * ``L1`` -- when ``self.config["feature_extraction"]`` is truthy, a tanh
      fully-connected layer with ``self.config["n_hidden_units"]`` outputs
      (truncated-normal weights, stddev 0.02, zero biases) applied to
      ``states``; otherwise ``states`` itself.
    * ``knowledge_base`` -- a ``tf.Variable`` of shape
      ``[last dim of L1, self.config["n_sparse_units"]]`` initialised from a
      truncated normal (mean 0.0, stddev 0.02).
    * ``shared_vars`` -- the TRAINABLE_VARIABLES collection filtered by the
      current variable-scope name (``"shared"`` when this method is called
      outside any enclosing variable scope).
    """
    with tf.variable_scope("shared"):
        # Batch dimension stays open; per-example dims come from the first env.
        state_shape = [None] + list(self.envs[0].observation_space.shape)
        self.states = tf.placeholder(tf.float32, state_shape, name="states")
        self.action_taken = tf.placeholder(tf.float32, name="action_taken")
        self.advantage = tf.placeholder(tf.float32, name="advantage")

        if not self.config["feature_extraction"]:
            # No learned feature layer: the raw states are used directly.
            self.L1 = self.states
        else:
            self.L1 = tf.contrib.layers.fully_connected(
                inputs=self.states,
                num_outputs=self.config["n_hidden_units"],
                activation_fn=tf.tanh,
                weights_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.02),
                biases_initializer=tf.zeros_initializer(),
                scope="L1",
            )

        # Rows of the knowledge base match the width of whatever L1 ended up being.
        feature_dim = self.L1.get_shape()[-1].value
        kb_initial_value = tf.truncated_normal(
            [feature_dim, self.config["n_sparse_units"]], mean=0.0, stddev=0.02)
        self.knowledge_base = tf.Variable(kb_initial_value, name="knowledge_base")

        scope_name = tf.get_variable_scope().name
        self.shared_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope_name)
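# "评论列表" (comment list) -- scraped web-page residue, not code
# "文章目录" (table of contents) -- scraped web-page residue, not code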