cyberpunk_trainer.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:third_person_im 作者: bstadie 项目源码 文件源码
def collect_trajs_for_cost(self, n_trajs, pol, env, dom, cls):
    """Roll out ``pol`` in ``env`` and package the image observations with labels.

    Runs ``n_trajs`` rollouts through ``self.cyberpunk_rollout``. Each rollout is
    capped at ``self.horizon`` steps and uses no reward extractor. Every path's
    ``'im_observations'`` are stacked into one array. The given class and domain
    labels are repeated so there is one label per trajectory per time step.

    Args:
        n_trajs: Number of rollouts to collect.
        pol: Policy, passed to the rollout as ``agent``.
        env: Environment to roll out in.
        dom: Domain label to tile into ``domains``. NOTE(review): presumably a
            1-D one-hot vector (an earlier, commented-out version used width 2);
            confirm against callers.
        cls: Class label to tile into ``classes``. Same assumption as ``dom``.

    Returns:
        dict with keys:
            ``data``: the stacked ``'im_observations'`` of all collected paths.
            ``classes``: ``np.tile(cls, (n_trajs, self.horizon, 1))``. For a
                1-D ``cls`` of length k, this has shape
                ``(n_trajs, self.horizon, k)``.
            ``domains``: built the same way from ``dom``.

    NOTE(review): the label arrays are sized with ``self.horizon``. They assume
    every path has exactly that many observations. If rollouts can terminate
    early, ``data`` and the labels will not line up, and stacking ragged paths
    may fail. Confirm what ``cyberpunk_rollout`` guarantees.
    """
    paths = [
        self.cyberpunk_rollout(
            agent=pol,
            env=env,
            max_path_length=self.horizon,
            reward_extractor=None,
        )
        for _ in range(n_trajs)
    ]

    data_matrix = tensor_utils.stack_tensor_list([p['im_observations'] for p in paths])

    # Every (trajectory, time step) pair gets the same label, so repeat the
    # label vector along two new leading axes.
    label_reps = (n_trajs, self.horizon, 1)
    class_matrix = np.tile(cls, label_reps)
    dom_matrix = np.tile(dom, label_reps)

    return dict(data=data_matrix, classes=class_matrix, domains=dom_matrix)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号