import itertools
import pickle
from collections import deque

import tensorflow as tf

# DQfDDDQN, DDQNConfig, Config and set_n_step are project-local names,
# assumed to be defined (or imported) elsewhere in this module.


def get_demo_data(env):
    # env = wrappers.Monitor(env, '/tmp/CartPole-v0', force=True)
    # agent.restore_model()
    with tf.variable_scope('get_demo_data'):
        agent = DQfDDDQN(env, DDQNConfig())
    e = 0
    while True:
        done = False
        score = 0  # sum of rewards over one episode
        state = env.reset()
        demo = []
        while not done:
            action = agent.egreedy_action(state)  # epsilon-greedy action for training
            next_state, reward, done, _ = env.step(action)
            score += reward
            # Penalize the terminal step unless the episode ran to the 500-step cap.
            reward = reward if not done or score == 500 else -100
            agent.perceive([state, action, reward, next_state, done, 0.0])  # 0.0 flags self-generated (non-demo) data
            demo.append([state, action, reward, next_state, done, 1.0])  # 1.0 flags a candidate expert transition
            agent.train_Q_network(update=False)
            state = next_state
        if done:
            if score == 500:  # a full-length episode qualifies as expert demo data
                demo = set_n_step(demo, Config.trajectory_n)  # augment with n-step returns (sketched below)
                agent.demo_buffer.extend(demo)
            agent.sess.run(agent.update_target_net)
            print("episode:", e, " score:", score, " demo_buffer:", len(agent.demo_buffer),
                  " memory length:", len(agent.replay_buffer), " epsilon:", agent.epsilon)
            if len(agent.demo_buffer) >= Config.demo_buffer_size:
                # Trim to exactly demo_buffer_size transitions, then stop collecting.
                agent.demo_buffer = deque(itertools.islice(agent.demo_buffer, 0, Config.demo_buffer_size))
                break
        e += 1
    with open(Config.DEMO_DATA_PATH, 'wb') as f:
        pickle.dump(agent.demo_buffer, f, protocol=2)  # protocol=2 keeps the file readable from Python 2
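Note that set_n_step is a helper defined elsewhere in this project, and its exact output format is not shown here. As a rough illustration of the idea only (not the repo's implementation), an n-step augmentation extends each transition of a finished episode with the discounted n-step return plus the state and done flag reached up to n steps later; the function name and the gamma discount factor below are assumptions for the sketch:

def set_n_step_sketch(trajectory, n, gamma=0.99):
    """Illustrative sketch only: extend each transition of a finished
    episode with (n-step return, n-step next state, n-step done)."""
    out = []
    for t in range(len(trajectory)):
        chunk = trajectory[t:t + n]  # up to n transitions starting at step t
        # Discounted sum of the rewards in the chunk (reward sits at index 2).
        n_step_reward = sum(gamma ** k * trans[2] for k, trans in enumerate(chunk))
        last = chunk[-1]
        # Append the n-step return plus the next_state (index 3) and
        # done flag (index 4) of the last transition in the chunk.
        out.append(trajectory[t] + [n_step_reward, last[3], last[4]])
    return out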
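For completeness, a minimal driver might look like the following sketch. It assumes the classic gym API (four-value env.step, as used in the loop above) and CartPole-v1, whose 500-step limit matches the score == 500 expert criterion:

import gym

if __name__ == '__main__':
    env = gym.make('CartPole-v1')  # 500-step cap matches the expert criterion
    get_demo_data(env)
    env.close()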