def test_network(self, reward_type):
    """
    Play the network over every image supplied by the validation-sized loader and report statistics.

    As stated by the original author: network testing is done on the validation data, which is what
    measures RL accuracy; the separate testing data only checks whether the algorithm learnt feature
    representations of objects.

    Every move is played with epsilon=0.0. Whenever a move reports a terminal state, the environment's
    normalized image difference and the accumulated episode reward are recorded and the next image is
    loaded. The loop stops once the loader reports that no further image is available, and the two
    collected lists are handed to performance_statistics.

    :param reward_type: forwarded unchanged to every call of __play_one_move
    :return: None
    """
    loader = self.__get_image_loader(NUM_VALIDATION_IMAGES)
    loader.load_next_image()
    current_state = self.__get_initial_state()

    # build the environment around the first image and its puzzle pieces
    puzzle_env = Environment(
        loader.get_image(),
        current_state,
        self.grid_dim,
        loader.get_puzzle_pieces(),
        IMAGE_HEIGHT,
        WINDOW_SIZE,
        SLIDING_STRIDE,
        self.input_channels,
        self.state_type,
    )

    episode_returns = []
    image_diffs = []
    running_return = 0.0

    while True:
        # the returned next state and action are not used here; the state is re-read from the
        # environment at the bottom of the loop instead
        _next_state, _action, step_reward, is_terminal = self.__play_one_move(
            current_state, puzzle_env, reward_type, epsilon=0.0
        )
        # NOTE(review): this scales the *earlier* rewards by GAMMA at each step, so the most recent
        # reward carries the largest weight -- confirm this is the intended accumulation.
        running_return = step_reward + GAMMA * running_return

        if is_terminal:
            # episode finished: record its statistics before switching images
            image_diffs.append(puzzle_env.get_normalized_image_diff())
            episode_returns.append(running_return)

            # stop as soon as the loader has no further image
            if not loader.load_next_image():
                break

            # point the existing environment at the freshly loaded image
            puzzle_env.update_puzzle_pieces(loader.get_puzzle_pieces())
            puzzle_env.update_original_image(loader.get_image())
            running_return = 0.0

        # carry the environment's current state into the next move
        current_state = puzzle_env.get_state()

    # display the reward and image matching performance statistics
    performance_statistics(image_diffs, episode_returns)
# Comment list
# Table of contents