def collect_trajs_for_cost(self, n_trajs, pol, env, dom, cls):
    """Roll out ``pol`` in ``env`` and package the results for cost training.

    Calls ``self.cyberpunk_rollout`` ``n_trajs`` times (with
    ``max_path_length=self.horizon`` and no reward extractor), stacks the
    ``'im_observations'`` entry of every returned path, and tiles the given
    class and domain labels so there is one copy per trajectory and time step.

    Args:
        n_trajs: Number of rollouts to collect.
        pol: Agent handed to ``cyberpunk_rollout`` as ``agent``.
        env: Environment handed to ``cyberpunk_rollout``.
        dom: Domain label, repeated ``(n_trajs, self.horizon, 1)`` times.
        cls: Class label, repeated ``(n_trajs, self.horizon, 1)`` times.

    Returns:
        dict with keys ``'data'`` (stacked image observations),
        ``'classes'`` (tiled ``cls``) and ``'domains'`` (tiled ``dom``).

    NOTE(review): presumably every rollout yields exactly ``self.horizon``
    image frames, so ``'data'`` lines up with the tiled labels as
    (n_trajs, horizon, height, width, channels) -- TODO confirm against
    ``cyberpunk_rollout``.
    """
    rollouts = [
        self.cyberpunk_rollout(
            agent=pol,
            env=env,
            max_path_length=self.horizon,
            reward_extractor=None,
        )
        for _ in range(n_trajs)
    ]

    image_frames = [rollout['im_observations'] for rollout in rollouts]
    data_matrix = tensor_utils.stack_tensor_list(image_frames)

    # The same label is attached to every time step of every trajectory.
    label_reps = (n_trajs, self.horizon, 1)
    class_matrix = np.tile(cls, label_reps)
    dom_matrix = np.tile(dom, label_reps)

    return {
        'data': data_matrix,
        'classes': class_matrix,
        'domains': dom_matrix,
    }
# (web-page residue) Comment list
# (web-page residue) Article table of contents