agent.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:ERL 作者: NoListen 项目源码 文件源码
def play(self, n_step=10000, n_episode=100, test_ep=True, render=False):
    """Run evaluation episodes and print per-episode and summary rewards.

    Each episode resets the environment, fills the frame history with the
    preprocessed first screen, then repeatedly asks ``self.predict`` for an
    action and steps the environment until it reports a terminal state or
    ``n_step`` steps have been taken.

    Args:
      n_step: maximum number of environment steps per episode.
      n_episode: number of episodes to run.
      test_ep: forwarded as the second argument of ``self.predict``;
        ``None`` selects ``self.ep_end``.
      render: not used by this method; kept so existing callers still work.

    Returns:
      None. Results are only printed to stdout.
    """
    # NOTE(review): with the default of True this fallback never fires and
    # True itself is handed to self.predict -- confirm that is intended.
    if test_ep is None:
      test_ep = self.ep_end

    def preprocess(frame):
      # Grayscale, resize to (110, 84), then keep rows 18..101 of the result.
      frame = imresize(rgb2gray(frame), (110, 84))
      return frame[18:102, :]

    test_history = History(self.config)

    # Record through a gym Monitor only when not displaying. Otherwise drive
    # the underlying env directly, so reset()/step() below always have a
    # target (previously `monitor` was undefined when self.display was set).
    env = self.env.env
    recording = not self.display
    if recording:
      gym_dir = './tmp/%s-%s' % (self.env_name, get_time())
      env = gym.wrappers.Monitor(env, gym_dir)

    best_reward, best_idx = 0, 0
    ep_rewards = []
    try:
      for idx in tqdm(range(n_episode), ncols=70):
        screen = preprocess(env.reset())
        current_reward = 0

        # Fill the whole history window with the first frame of the episode.
        for _fill in range(self.history_length):
          test_history.add(screen)

        for _step in range(n_step):
          # 1. predict
          action = self.predict(test_history.get(), test_ep)
          # 2. act
          screen, reward, terminal, _info = env.step(action)
          # 3. observe
          test_history.add(preprocess(screen))

          current_reward += reward
          if terminal:
            break

        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("GET REWARD %s" % (current_reward,))
        ep_rewards.append(current_reward)
        # Seed the best score from the first episode so that a run with only
        # negative rewards is not reported as "best reward 0".
        if idx == 0 or current_reward > best_reward:
          best_reward = current_reward
          best_idx = idx

      print("=" * 30)
      print(" [%d] Best reward : %d Average reward: %f"
            % (best_idx, best_reward, np.mean(ep_rewards)))
      print("=" * 30)
    finally:
      # Close the monitor even if an episode raised.
      if recording:
        env.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号