cyberpunk_trainer.py source code

python

Project: third_person_im  Author: bstadie
import numpy as np


def shuffle_to_training_data(self, expert_data, on_policy_data, expert_fail_data):
    # Stack the three sources along the trajectory axis. 'data' is indexed as
    # (n_traj, t_steps, H, W, C); 'classes' and 'domains' as (n_traj, t_steps, 2).
    data = np.vstack([expert_data['data'], on_policy_data['data'], expert_fail_data['data']])
    classes = np.vstack([expert_data['classes'], on_policy_data['classes'], expert_fail_data['classes']])
    domains = np.vstack([expert_data['domains'], on_policy_data['domains'], expert_fail_data['domains']])

    # Shuffle every (trajectory, timestep) pair, encoded as one flat index.
    sample_range = data.shape[0] * data.shape[1]
    all_idxs = np.random.permutation(sample_range)

    t_steps = data.shape[1]

    data_matrix = np.zeros(shape=(sample_range, self.im_height, self.im_width, self.im_channels))
    data_matrix_two = np.zeros(shape=(sample_range, self.im_height, self.im_width, self.im_channels))
    class_matrix = np.zeros(shape=(sample_range, 2))
    dom_matrix = np.zeros(shape=(sample_range, 2))
    for iter_step, one_idx in enumerate(all_idxs):
        # Integer division: np.floor returns a float, which NumPy rejects as an index.
        traj_key = one_idx // t_steps
        time_key = one_idx % t_steps
        # Despite its name, the paired frame is 3 steps ahead, clamped to the last frame.
        time_key_plus_one = min(time_key + 3, t_steps - 1)
        data_matrix[iter_step, :, :, :] = data[traj_key, time_key, :, :, :]
        data_matrix_two[iter_step, :, :, :] = data[traj_key, time_key_plus_one, :, :, :]
        class_matrix[iter_step, :] = classes[traj_key, time_key, :]
        dom_matrix[iter_step, :] = domains[traj_key, time_key, :]
    # Note the return order: domain labels come before class labels.
    return data_matrix, data_matrix_two, dom_matrix, class_matrix
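The per-sample loop in the method above can also be expressed with NumPy advanced indexing. The following is a minimal sketch, not code from third_person_im: the name `shuffle_to_training_pairs` and the `frame_offset` and `rng` parameters are hypothetical, and it assumes the same layout the method indexes into (a 5-D 'data' array and 3-D 'classes' and 'domains' arrays that share the trajectory and timestep axes).

```python
import numpy as np


def shuffle_to_training_pairs(expert_data, on_policy_data, expert_fail_data,
                              frame_offset=3, rng=None):
    """Hypothetical vectorized variant of shuffle_to_training_data.

    Assumes 'data' is (n_traj, T, H, W, C) and 'classes'/'domains' are
    (n_traj, T, 2), with the same T across all three sources.
    """
    rng = np.random.default_rng() if rng is None else rng
    sources = (expert_data, on_policy_data, expert_fail_data)
    data = np.concatenate([d['data'] for d in sources], axis=0)
    classes = np.concatenate([d['classes'] for d in sources], axis=0)
    domains = np.concatenate([d['domains'] for d in sources], axis=0)

    n_traj, t_steps = data.shape[:2]
    # One random permutation over every (trajectory, timestep) pair.
    flat_idxs = rng.permutation(n_traj * t_steps)
    # Split each flat index into trajectory and timestep (same as // and %).
    traj_key, time_key = np.divmod(flat_idxs, t_steps)
    # Partner frame a fixed number of steps later, clamped to the last frame.
    time_key_later = np.minimum(time_key + frame_offset, t_steps - 1)

    # Indexing with two integer arrays gathers all samples in one step.
    data_matrix = data[traj_key, time_key]
    data_matrix_two = data[traj_key, time_key_later]
    class_matrix = classes[traj_key, time_key]
    dom_matrix = domains[traj_key, time_key]
    # Same return order as the method: domains before classes.
    return data_matrix, data_matrix_two, dom_matrix, class_matrix
```

One behavioral difference worth noting: these outputs keep the input dtype (uint8 frames stay uint8), whereas the loop version always returns float64 arrays because it preallocates with `np.zeros`.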