def select_episodes(dataset, n_episodes, parse=False):
    """
    Return the first `n_episodes` episodes in the provided dataset.

    An episode is considered finished at every row whose last entry equals 1;
    everything after the `n_episodes`-th such row is dropped.

    Args:
        dataset (list): the dataset to consider, one row per step; the last
            entry of each row flags the end of an episode;
        n_episodes (int): the number of episodes to pick from the dataset;
        parse (bool, False): whether to parse the dataset to return.

    Returns:
        A numpy array with the rows of the first `n_episodes` episodes, passed
        through `parse_dataset` when `parse` is True. When `n_episodes` is 0,
        an empty array of shape (1, 0) is returned regardless of `parse`.

    Raises:
        AssertionError: if `n_episodes` is negative.
        IndexError: if the dataset contains fewer than `n_episodes`
            finished episodes.
    """
    assert n_episodes >= 0, 'Number of episodes must be greater than or ' \
                            'equal to zero.'
    if n_episodes == 0:
        # NOTE(review): `parse` is ignored on this path; confirm whether
        # callers expect `parse_dataset` to be applied to the empty result.
        return np.array([[]])

    try:
        dataset = np.array(dataset)
    except ValueError:
        # Rows that mix array-valued entries with scalars are "ragged";
        # NumPy >= 1.24 refuses to build them without an explicit object
        # dtype. Numeric inputs take the first branch and keep their dtype.
        dataset = np.array(dataset, dtype=object)

    # Row indices where an episode ends (last column flagged with 1).
    last_idxs = np.argwhere(dataset[:, -1] == 1).ravel()
    if n_episodes > len(last_idxs):
        raise IndexError(
            'Requested {} episodes, but the dataset contains only {} '
            'finished episodes.'.format(n_episodes, len(last_idxs)))

    # Keep every row up to and including the end of the n-th episode.
    sub_dataset = dataset[:last_idxs[n_episodes - 1] + 1, :]

    return sub_dataset if not parse else parse_dataset(sub_dataset)
评论列表
文章目录