def process(self, sess):
    """Consume one queued rollout and run a single training update.

    Steps, in order:
      1. Run ``self.sync`` (described by the original author as copying
         weights from the shared network to the local one).
      2. Pull a rollout with ``self.pull_batch_from_queue()`` and turn it
         into a batch via ``process_rollout`` using ``self.gamma`` and
         ``lambda_=1.0``.
      3. Run ``self.train_op`` and ``self.global_step`` with the batch fed
         in. The original note states that this update is then sent to the
         parameter server -- that happens inside ``train_op``, which is not
         visible here.
      4. When ``self.task == 0`` and ``self.local_steps`` is a multiple of
         11, additionally fetch ``self.summary_op`` and write it to
         ``self.summary_writer``.

    ``self.local_steps`` is incremented by one at the end of every call.

    Args:
        sess: session object exposing ``run(fetches, feed_dict=...)``.
    """
    sess.run(self.sync)  # copy weights from shared to local

    queued_rollout = self.pull_batch_from_queue()
    batch = process_rollout(queued_rollout, self.gamma, lambda_=1.0)

    # Summaries are only produced by task 0, and only on every 11th local step.
    write_summary = self.task == 0 and self.local_steps % 11 == 0

    # global_step is always the last fetch; when a summary is requested,
    # summary_op is placed first so it can be read back at index 0.
    ops = [self.train_op, self.global_step]
    if write_summary:
        ops = [self.summary_op] + ops

    net = self.local_network
    inputs = {
        net.x: batch.si,
        self.ac: batch.a,
        self.adv: batch.adv,
        self.r: batch.r,
    }
    # Pair each state_in placeholder with the corresponding batch feature.
    inputs.update(zip(net.state_in, batch.features))

    results = sess.run(ops, feed_dict=inputs)

    if write_summary:
        summary_proto = tf.Summary.FromString(results[0])
        self.summary_writer.add_summary(summary_proto, results[-1])
        self.summary_writer.flush()

    self.local_steps += 1
评论列表
文章目录