def shuffle_to_training_data(self, expert_data, on_policy_data, expert_fail_data, *, step_offset=3):
    """Pool three trajectory datasets into a shuffled set of frame pairs.

    Each input is a dict with ``'data'``, ``'classes'`` and ``'domains'``
    arrays whose first two axes are (trajectory, time step). The three
    datasets are stacked along the trajectory axis, so they must agree on
    every other axis (in particular on the number of time steps). Every
    (trajectory, time step) sample is then emitted exactly once, in a random
    order drawn from ``np.random.permutation``.

    Args:
        expert_data, on_policy_data, expert_fail_data: dicts holding
            ``'data'`` (frames; presumably (n_traj, T, H, W, C) — TODO
            confirm), plus ``'classes'`` and ``'domains'`` label arrays
            indexed as (trajectory, time step, label) whose last axis must
            broadcast to 2.
        step_offset: how many time steps ahead the paired frame is taken
            from. The index is clamped to the trajectory's last step.
            Defaults to 3, the value that used to be hard-coded.

    Returns:
        ``(data_matrix, data_matrix_two, dom_matrix, class_matrix)``. Note
        that domains come *before* classes.

        - ``data_matrix``: shape (N, self.im_height, self.im_width,
          self.im_channels), with N = n_traj * T.
        - ``data_matrix_two``: the frame ``step_offset`` steps later in the
          same trajectory.
        - ``dom_matrix`` and ``class_matrix``: (N, 2) labels of the *first*
          frame.

        All outputs are float64.

    Raises:
        ValueError: if ``step_offset`` is negative.
    """
    if step_offset < 0:
        raise ValueError(f"step_offset must be non-negative, got {step_offset}")

    datasets = (expert_data, on_policy_data, expert_fail_data)
    data, classes, domains = (
        np.vstack([dataset[key] for dataset in datasets])
        for key in ('data', 'classes', 'domains')
    )

    t_steps = data.shape[1]
    sample_range = data.shape[0] * t_steps
    all_idxs = np.random.permutation(sample_range)

    # Integer (trajectory, time) decomposition of each flat sample index.
    # The former np.floor(...) produced floats, which NumPy rejects as indices.
    traj_idx, time_idx = np.divmod(all_idxs, t_steps)
    # Clamp so samples near the end of a trajectory pair with its final frame.
    next_time_idx = np.minimum(time_idx + step_offset, t_steps - 1)

    image_shape = (sample_range, self.im_height, self.im_width, self.im_channels)
    data_matrix = np.zeros(shape=image_shape)
    data_matrix_two = np.zeros(shape=image_shape)
    class_matrix = np.zeros(shape=(sample_range, 2))
    dom_matrix = np.zeros(shape=(sample_range, 2))

    # Vectorized gather. Slice-assignment into the preallocated float64
    # buffers keeps the output dtype and broadcasting of a per-sample loop.
    data_matrix[...] = data[traj_idx, time_idx]
    data_matrix_two[...] = data[traj_idx, next_time_idx]
    class_matrix[...] = classes[traj_idx, time_idx]
    dom_matrix[...] = domains[traj_idx, time_idx]
    return data_matrix, data_matrix_two, dom_matrix, class_matrix
评论列表
文章目录