def predict(self, path):
    """Return the exploration bonus for every step of ``path``.

    The per-step bonus is a weighted sum of:
      * ``self.visitation_bonus``      * ``self.predict_count(...)``
      * ``self.dist_from_reset_bonus`` * ``self.predict_dist_from_reset(...)``
      * ``self.snn_H_bonus``           * ``self.predict_entropy(...)``
        (only added when ``self.snn_H_bonus`` is truthy)
      * ``self.survival_bonus`` as a constant added to every step.

    If ``path['env_infos']`` contains a ``'full_path'`` entry, ``path`` is
    treated as aggregated. The ``full_path`` dict is flattened with
    ``tensor_utils.flatten_first_axis_tensor_dict`` and the bonus is computed
    on the flattened steps. The per-step values are then summed back into one
    value per entry of ``full_path['rewards']``.

    Otherwise ``path`` is used as-is. Paths from ``log_diagnostics`` are
    already expanded, and so are paths that were never aggregated.

    Args:
        path: rollout dict. This method itself only reads
            ``path['env_infos']['full_path']`` (when present). The rest is
            forwarded to the ``predict_*`` helpers.

    Returns:
        np.ndarray of bonuses: one per step of ``path``, or one per entry of
        ``full_path['rewards']`` in the aggregated case.
    """
    # Decide once whether this path is aggregated, so both uses agree.
    is_aggregated = 'env_infos' in path and 'full_path' in path['env_infos']

    if is_aggregated:
        expanded_path = tensor_utils.flatten_first_axis_tensor_dict(
            path['env_infos']['full_path'])
    else:
        # Coming from log_diagnostics it's already expanded (or it was never
        # aggregated), so the path can be scored directly.
        expanded_path = path

    bonus = (self.visitation_bonus * self.predict_count(expanded_path)
             + self.dist_from_reset_bonus * self.predict_dist_from_reset(expanded_path))
    if self.snn_H_bonus:
        # Guarded because the SNN entropy bonus is only available when
        # there are latents.
        bonus += self.snn_H_bonus * self.predict_entropy(expanded_path)
    total_bonus = bonus + self.survival_bonus * np.ones_like(bonus)

    if not is_aggregated:
        return np.array(total_bonus)

    # Sum the per-step bonuses back into one value per aggregated step.
    # Each entry of full_path['rewards'] supplies that sub-path's length.
    # NOTE(review): this assumes flatten_first_axis_tensor_dict concatenates
    # sub-paths in the same order as full_path['rewards']. If the lengths do
    # not add up to len(total_bonus), trailing values are silently dropped
    # (or slices come back empty) -- confirm against callers.
    aggregated_bonus = []
    start = 0
    for sub_rewards in path['env_infos']['full_path']['rewards']:
        stop = start + len(sub_rewards)
        aggregated_bonus.append(np.sum(total_bonus[start:stop]))
        start = stop
    return np.array(aggregated_bonus)
评论列表
文章目录