def run():
    """Build the shared A3C graph, spawn one worker thread per agent, and train.

    Reads configuration from FLAGS (lr, use_conv, nb_concurrent, game, resume,
    checkpoint_dir, model_name, show_training). Side effects: recreates the
    experiment directory structure, constructs one gym environment per worker,
    and blocks until the coordinator reports all worker threads stopped.
    """
    recreate_directory_structure()
    tf.reset_default_graph()
    sess = tf.Session()
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
    with sess:
        with tf.device("/cpu:0"):
            global_step = tf.Variable(0, dtype=tf.int32, name='global_episodes',
                                      trainable=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)
            # The 'global' network holds the shared parameters that each
            # worker's local network syncs with; it is used via the graph,
            # not through this local variable.
            if FLAGS.use_conv:
                global_network = ConvNetwork('global', None)
            else:
                global_network = ACNetwork('global', None)

            # num_agents = multiprocessing.cpu_count()
            num_agents = FLAGS.nb_concurrent
            envs = []
            for i in range(num_agents):
                gym_env = gym.make(FLAGS.game)
                # if FLAGS.monitor:
                #     gym_env = gym.wrappers.Monitor(gym_env, FLAGS.experiments_dir + '/worker_{}'.format(i), force=True)
                envs.append(gym_env)
            agents = [Agent(envs[i], i, optimizer, global_step)
                      for i in range(num_agents)]

            saver = tf.train.Saver(max_to_keep=5)
            coord = tf.train.Coordinator()
            if FLAGS.resume:
                # NOTE(review): get_checkpoint_state returns None when no
                # checkpoint exists, which would raise AttributeError below —
                # presumably resume is only set when a checkpoint is present.
                ckpt = tf.train.get_checkpoint_state(
                    os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name))
                print("Loading Model from {}".format(ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                sess.run(tf.global_variables_initializer())

            agent_threads = []
            for agent in agents:
                # BUG FIX: the original used `target=(lambda: agent.play(...))`,
                # which captures the loop variable `agent` by reference (late
                # binding). A thread scheduled after the loop advances would
                # call play() on the wrong agent. Passing the bound method with
                # explicit args binds each agent at thread-creation time.
                thread = threading.Thread(target=agent.play,
                                          args=(sess, coord, saver))
                thread.start()
                agent_threads.append(thread)

            # BUG FIX: the original looped `while True:`, which never exits and
            # made the coord.join() below unreachable. Polling the coordinator
            # lets the main thread stop rendering once workers request a stop.
            while not coord.should_stop():
                if FLAGS.show_training:
                    for env in envs:
                        # time.sleep(1)
                        # with main_lock:
                        env.render()
            coord.join(agent_threads)