def run():
    recreate_directory_structure()
    tf.reset_default_graph()
    sess = tf.Session()
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
    with sess:
        with tf.device("/cpu:0"):
            # Shared episode counter, incremented by the worker threads.
            global_step = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)
            # num_agents = multiprocessing.cpu_count()
            num_agents = FLAGS.nb_concurrent
            agents = []
            envs = []
            # Build one environment per worker thread.
            for i in range(num_agents):
                gym_env = gym.make(FLAGS.game)
                # if FLAGS.monitor:
                #     gym_env = gym.wrappers.Monitor(gym_env, FLAGS.experiments_dir + '/worker_{}'.format(i), force=True)
                if FLAGS.game not in flags.SUPPORTED_ENVS:
                    gym_env = atari_environment.AtariEnvironment(gym_env=gym_env, resized_width=FLAGS.resized_width,
                                                                 resized_height=FLAGS.resized_height,
                                                                 agent_history_length=FLAGS.agent_history_length)
                    FLAGS.nb_actions = len(gym_env.gym_actions)
                envs.append(gym_env)
            # The global network owns the shared parameters every worker syncs from.
            global_network = FUNNetwork('global', None)
            for i in range(num_agents):
                agents.append(Agent(envs[i], i, optimizer, global_step))
            saver = tf.train.Saver(max_to_keep=5)
        coord = tf.train.Coordinator()
        if FLAGS.resume:
            # Restore the latest checkpoint instead of training from scratch.
            ckpt = tf.train.get_checkpoint_state(os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name))
            print("Loading Model from {}".format(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
        agent_threads = []
        for agent in agents:
            # Pass the agent's play method directly instead of a lambda that
            # closes over the loop variable, which would otherwise be rebound
            # before the thread gets a chance to run.
            thread = threading.Thread(target=agent.play, args=(sess, coord, saver))
            thread.start()
            agent_threads.append(thread)
        # Render each environment until the coordinator signals the workers
        # to stop; the original unconditional `while True` made the final
        # coord.join unreachable.
        while FLAGS.show_training and not coord.should_stop():
            for env in envs:
                # time.sleep(1)
                # with main_lock:
                env.render()
        coord.join(agent_threads)
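
For context, here is a minimal sketch of how run() might be wired up as the script's entry point under TensorFlow 1.x; the main() wrapper and the __main__ guard are illustrative assumptions, not part of the original source:

def main(_):
    run()

if __name__ == '__main__':
    # tf.app.run() parses the tf.app.flags definitions behind FLAGS
    # and then invokes main().
    tf.app.run()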