agent.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:harpreif 作者: harpribot 项目源码 文件源码
def test_network(self, reward_type):
    """
    Evaluate the current network on the validation images.

    Loads NUM_VALIDATION_IMAGES images and plays each puzzle with
    epsilon=0.0 until the environment reports a terminal state. For every
    finished image it records the accumulated episode reward and
    ``env.get_normalized_image_diff()``. After the last image, the collected
    lists are passed to ``performance_statistics``.

    Per the original design notes, the validation data is what measures RL
    accuracy; a separate testing set (not used by this method) is meant to
    check whether object feature representations were learnt.

    :param reward_type: forwarded unchanged to ``__play_one_move``; the
        accepted values are defined there.
    :return: None. Also returns early without reporting anything if the
        loader yields no first image.
    """
    imagenet = self.__get_image_loader(NUM_VALIDATION_IMAGES)
    # load_next_image() reports whether an image was actually loaded (the
    # loop below relies on the same signal); with no images there is
    # nothing to evaluate and no image to build the environment from.
    if not imagenet.load_next_image():
        return

    state = self.__get_initial_state()
    # initialize the environment with the first validation image
    env = Environment(imagenet.get_image(), state, self.grid_dim, imagenet.get_puzzle_pieces(),
                      IMAGE_HEIGHT, WINDOW_SIZE, SLIDING_STRIDE,
                      self.input_channels, self.state_type)
    reward_list = []
    image_diff_list = []
    episode_reward = 0.0
    while True:
        # epsilon=0.0 presumably disables random exploration (greedy play)
        # -- confirm in __play_one_move. The returned next state and action
        # are not needed: the state is re-read from the environment below.
        _, _, reward, terminal = self.__play_one_move(state, env, reward_type, epsilon=0.0)
        # NOTE(review): this recurrence gives the latest reward weight 1 and
        # discounts earlier rewards by GAMMA per later step (the reverse of
        # the usual r0 + GAMMA*r1 + ...). Confirm this matches training.
        episode_reward = reward + GAMMA * episode_reward
        # if terminal state has reached, then move to the next image
        if terminal:
            image_diff_list.append(env.get_normalized_image_diff())
            reward_list.append(episode_reward)

            image_present = imagenet.load_next_image()

            if image_present:
                env.update_puzzle_pieces(imagenet.get_puzzle_pieces())
                env.update_original_image(imagenet.get_image())
                episode_reward = 0.0
            else:
                break
        # Refresh the state from the environment. NOTE(review): after an
        # image switch this assumes the update_* calls leave env in a valid
        # starting state for the new puzzle -- confirm in Environment.
        state = env.get_state()
    # display the reward and image matching performance statistics
    performance_statistics(image_diff_list, reward_list)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号