def learn(self):
    """Run one training step on a random minibatch from ``self.memory``.

    Steps, in order:

    1. Hard-replace parameters: when ``self.a_replace_counter`` is a
       multiple of ``REPLACE_ITER_A``, every variable in ``self.at_params``
       is overwritten with the matching variable in ``self.ae_params``;
       likewise ``self.ct_params`` from ``self.ce_params`` every
       ``REPLACE_ITER_C`` calls. Both counters are then incremented.
    2. Draw ``BATCH_SIZE`` row indices (with replacement) from
       ``range(MEMORY_CAPACITY)`` and slice the rows into the arrays fed to
       ``self.S``, ``self.a``, ``self.R`` and ``self.S_``.
    3. Run ``self.atrain`` and then ``self.ctrain`` in ``self.sess``.

    The assign ops used in step 1 are created on first use and cached on
    the instance (``_a_replace_ops`` / ``_c_replace_ops``). Creating them
    with ``tf.assign`` on every replace would keep adding nodes to the
    graph for the lifetime of the session.
    """
    # hard replace parameters
    if self.a_replace_counter % REPLACE_ITER_A == 0:
        a_replace_ops = getattr(self, "_a_replace_ops", None)
        if a_replace_ops is None:
            # Built once, then reused: avoids unbounded graph growth.
            a_replace_ops = [
                tf.assign(t, e)
                for t, e in zip(self.at_params, self.ae_params)
            ]
            self._a_replace_ops = a_replace_ops
        self.sess.run(a_replace_ops)
    if self.c_replace_counter % REPLACE_ITER_C == 0:
        c_replace_ops = getattr(self, "_c_replace_ops", None)
        if c_replace_ops is None:
            c_replace_ops = [
                tf.assign(t, e)
                for t, e in zip(self.ct_params, self.ce_params)
            ]
            self._c_replace_ops = c_replace_ops
        self.sess.run(c_replace_ops)
    self.a_replace_counter += 1
    self.c_replace_counter += 1

    # NOTE(review): indices span the whole buffer regardless of how many
    # rows have been written; presumably callers only invoke learn() once
    # the memory is full -- confirm against the caller.
    indices = np.random.choice(MEMORY_CAPACITY, size=BATCH_SIZE)
    bt = self.memory[indices, :]

    # Row layout implied by the slices below:
    #   [ s_dim cols | a_dim cols | 1 col | s_dim cols ]
    #      -> self.S    -> self.a  -> self.R  -> self.S_
    bs = bt[:, :self.s_dim]
    ba = bt[:, self.s_dim: self.s_dim + self.a_dim]
    br = bt[:, -self.s_dim - 1: -self.s_dim]  # kept 2-D: shape (batch, 1)
    bs_ = bt[:, -self.s_dim:]

    self.sess.run(self.atrain, {self.S: bs})
    self.sess.run(self.ctrain, {self.S: bs, self.a: ba, self.R: br, self.S_: bs_})
# [web page navigation residue] Comment list
# [web page navigation residue] Article table of contents