def run(settings):
    """Build the A3C-style training graph and run one worker agent per thread.

    Constructs the shared ('global') network and optimizer on the CPU, creates
    one environment + Agent pair per worker, then either restores a checkpoint
    (when FLAGS.resume is set) or initializes variables, and finally launches
    each agent's play loop on its own thread, joining them via a Coordinator.

    Args:
        settings: dict-like config. Keys read here: "lr" (learning rate),
            "game" ('11arms' selects ElevenArms, anything else TwoArms),
            "checkpoint_dir" and "load_from" (checkpoint locations).
            Other keys are presumably consumed by Agent — not visible here.

    Side effects: creates directories (recreate_subdirectory_structure),
    resets the default TF graph, may read a checkpoint from disk.
    """
    recreate_subdirectory_structure(settings)
    tf.reset_default_graph()

    with tf.device("/cpu:0"):
        # Shared step counter and parameter server ('global') network; each
        # worker Agent syncs to/from this network.
        global_step = tf.Variable(0, dtype=tf.int32, name='global_episodes',
                                  trainable=False)
        optimizer = tf.train.AdamOptimizer(learning_rate=settings["lr"])
        global_network = ACNetwork('global', None)

        num_agents = 1
        agents = []
        # One environment per agent, built in the same loop as its Agent so
        # env index i and agent index i can never drift apart.
        for i in range(num_agents):
            if settings["game"] == '11arms':
                this_env = ElevenArms()
            else:
                this_env = TwoArms(settings["game"])
            agents.append(Agent(this_env, i, optimizer, global_step, settings))

    saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        if FLAGS.resume:
            if FLAGS.hypertune:
                ckpt = tf.train.get_checkpoint_state(settings["checkpoint_dir"])
            else:
                ckpt = tf.train.get_checkpoint_state(settings["load_from"])
            print("Loading Model from {}".format(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        agent_threads = []
        for agent in agents:
            # BUG FIX: the original used `lambda: agent.play(sess, coord, saver)`,
            # which late-binds `agent` — with num_agents > 1 every thread would
            # run the *last* agent. Binding the method as `target` with explicit
            # `args` captures the correct agent at thread-creation time.
            thread = threading.Thread(target=agent.play,
                                      args=(sess, coord, saver))
            thread.start()
            agent_threads.append(thread)
        coord.join(agent_threads)
# (removed non-code scraped-page residue: "评论列表" / "文章目录" — comment list / article TOC)