def a3c(game_name, num_threads=8, restore=None, save_path='model'):
    processes = []
    envs = []
    for _ in range(num_threads+1):
        gym_env = gym.make(game_name)
        if game_name == 'CartPole-v0':
            env = CustomGymClassicControl(game_name)
        else:
            print "Assuming ATARI game and playing with pixels"
            env = CustomGym(game_name)
        envs.append(env)

    # Separate out the evaluation environment
    evaluation_env = envs[0]
    envs = envs[1:]

    with tf.Session() as sess:
        agent = Agent(session=sess,
            action_size=envs[0].action_size, model='mnih',
            optimizer=tf.train.AdamOptimizer(INITIAL_LEARNING_RATE))

        # Create a saver, and only keep 2 checkpoints.
        saver = tf.train.Saver(max_to_keep=2)

        T_queue = Queue.Queue()

        # Either restore the parameters or don't.
        if restore is not None:
            saver.restore(sess, save_path + '-' + str(restore))
            last_T = restore
            print "T was:", last_T
            T_queue.put(last_T)
        else:
            sess.run(tf.global_variables_initializer())
            T_queue.put(0)

        summary = Summary(save_path, agent)

        # Create a thread for each training worker
        for i in range(num_threads):
            processes.append(threading.Thread(target=async_trainer, args=(agent,
                envs[i], sess, i, T_queue, summary, saver, save_path,)))

        # Create a thread to evaluate the agent
        processes.append(threading.Thread(target=evaluator, args=(agent,
            evaluation_env, sess, T_queue, summary, saver, save_path,)))

        # Start all the worker and evaluator threads as daemons
        for p in processes:
            p.daemon = True
            p.start()

        # Spin until training is finished
        while not training_finished:
            sleep(0.01)

        # Join the threads, so we get this thread back.
        for p in processes:
            p.join()
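# A hypothetical entry point showing how a3c might be launched; the game name,
# thread count and save path below are illustrative values, not the article's.
if __name__ == '__main__':
    a3c('Pong-v0', num_threads=8, save_path='model')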
# Returns sum(rewards[i] * gamma**i)
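# A minimal sketch of the discount helper that the comment above describes.
# The name `discount` and its signature are assumptions; the article's own
# implementation may differ.
def discount(rewards, gamma):
    # Accumulate the discounted return sum(rewards[i] * gamma**i)
    total = 0.0
    for i, r in enumerate(rewards):
        total += r * gamma**i
    return total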