dataset.py source code

python

Project: mushroom  Author: carloderamo
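import numpy as np


def select_episodes(dataset, n_episodes, parse=False):
    """
    Return the first `n_episodes` episodes in the provided dataset.

    Args:
        dataset (list): the dataset to consider; the final entry of each
            sample is the "last" flag, equal to 1 on the final step of an
            episode;
        n_episodes (int): the number of episodes to pick from the dataset;
        parse (bool, False): whether to pass the selected samples through
            `parse_dataset` (a helper from the same module, not shown here)
            before returning them.

    Returns:
        A subset of the dataset containing the first `n_episodes` episodes.

    """
    assert n_episodes >= 0, 'Number of episodes must be greater than or ' \
                            'equal to zero.'
    if n_episodes == 0:
        return np.array([[]])

    # Assumes samples are flat rows, so NumPy builds a 2-D array.
    dataset = np.array(dataset)
    # Row indices where the "last" flag (final column) is set, i.e. where
    # each episode ends.
    last_idxs = np.argwhere(dataset[:, -1] == 1).ravel()
    assert n_episodes <= len(last_idxs), \
        'The dataset contains fewer complete episodes than requested.'
    # Keep every sample up to and including the end of the n-th episode.
    sub_dataset = dataset[:last_idxs[n_episodes - 1] + 1, :]

    return sub_dataset if not parse else parse_dataset(sub_dataset)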
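To see how the episode boundaries drive the slice, here is a minimal stand-alone sketch of the same logic. `first_episodes` is a hypothetical name, the four-column (state, action, reward, last) rows are an illustrative assumption rather than mushroom's actual sample layout, and raising `ValueError` when too few episodes exist is an added choice, not the library's behavior.

```python
import numpy as np


def first_episodes(dataset, n_episodes):
    """Hypothetical stand-alone copy of the slicing logic in select_episodes.

    Assumes every sample is a flat sequence of scalars whose final entry is
    the "last" flag (1 on the final step of an episode, 0 otherwise).
    """
    if n_episodes == 0:
        return np.array([[]])
    data = np.array(dataset)
    # Row indices where an episode ends.
    last_idxs = np.argwhere(data[:, -1] == 1).ravel()
    if n_episodes > len(last_idxs):
        raise ValueError('dataset contains only %d complete episodes'
                         % len(last_idxs))
    # Slice up to and including the final step of the n-th episode.
    return data[:last_idxs[n_episodes - 1] + 1, :]


# Toy dataset (assumed layout): each row is (state, action, reward, last).
# Episode 1 has 2 steps, episode 2 has 3 steps, episode 3 has 1 step.
toy = [
    (0, 1, 0.0, 0),
    (1, 0, 1.0, 1),
    (0, 1, 0.0, 0),
    (2, 1, 0.5, 0),
    (3, 0, 1.0, 1),
    (5, 1, 2.0, 1),
]

two = first_episodes(toy, 2)
print(two.shape)  # (5, 4)
```

The "last" flags sit at rows 1, 4 and 5, so asking for two episodes keeps rows 0 through 4. Because the slice always ends on an episode boundary, trailing samples of an unfinished episode are never included.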