def build_network_rnn(self):
    """Build a recurrent Gaussian policy graph over one sequence of states.

    All states fed in a single run are treated as ONE time-ordered sequence
    (batch size 1) and passed through a GRU. The per-step GRU outputs
    parameterize a Normal distribution, and the action is sampled from it.

    Sets on ``self``:
        states: float32 placeholder of shape ``[None] + observation_shape``.
        a_n: float32 placeholder for the actions taken (no static shape).
        adv_n: float32 placeholder for the advantages (no static shape).
        normal_dist: the Normal(mu, sigma) action distribution.
        action: one sample from ``normal_dist``, clipped to the bounds of
            action dimension 0.
    """
    self.states = tf.placeholder(
        tf.float32,
        [None] + list(self.env.observation_space.shape),
        name="states",
    )
    # Actions are sampled from a Normal distribution below, so they are
    # continuous float32 values, not discrete indices.
    self.a_n = tf.placeholder(tf.float32, name="a_n")
    self.adv_n = tf.placeholder(tf.float32, name="adv_n")
    # NOTE(review): a_n and adv_n are not used in this method; presumably
    # the loss built elsewhere consumes them -- confirm.

    # Number of fed states as a shape-[1] int32 tensor: one sequence-length
    # entry for the single sequence in the size-1 batch built below.
    n_states = tf.shape(self.states)[:1]
    # Add a leading batch axis of size 1 so all states form one sequence.
    # NOTE(review): assumes `flatten` (defined elsewhere) turns each
    # observation into a flat feature vector -- confirm.
    states = tf.expand_dims(flatten(self.states), [0])

    enc_cell = tf.contrib.rnn.GRUCell(self.config["n_hidden_units"])
    rnn_outputs, _ = tf.nn.dynamic_rnn(
        cell=enc_cell,
        inputs=states,
        sequence_length=n_states,
        dtype=tf.float32,
    )
    # Drop the size-1 batch axis: one GRU output per state.
    rnn_outputs = rnn_outputs[0]

    # NOTE(review): mu_sigma_layer is defined elsewhere; presumably it returns
    # the mean and a positive std of a 1-output Gaussian head -- confirm.
    mu, sigma = mu_sigma_layer(rnn_outputs, 1)
    self.normal_dist = tf.contrib.distributions.Normal(mu, sigma)
    # sample(1) prepends a sample axis of size 1. The clip uses only the
    # bounds of action dimension 0, which matches the single output requested
    # from mu_sigma_layer, i.e. a 1-D action space is assumed.
    self.action = tf.clip_by_value(
        self.normal_dist.sample(1),
        self.env.action_space.low[0],
        self.env.action_space.high[0],
    )
评论列表
文章目录