classifier_tf.py source code

python

Project: human-rl, author: gsastry
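```python
import math
import numpy as np

def split_episodes(self, episode_paths, n_train, n_valid, n_test, seed=None, use_all=True):
    """Split episodes between training, validation and test sets.

    episode_paths: list of episode paths; it is shuffled in place.
    n_train, n_valid, n_test: number of episodes per split. When use_all is
        True they are treated as ratios and scaled to the size of the list.
    seed: random seed (makes the split the same every time). The global
        NumPy random state is restored afterwards.
    """
    if seed is not None:
        # Shuffle with a fixed seed, then restore the global RNG state so
        # that other code relying on np.random is unaffected.
        random_state = np.random.get_state()
        np.random.seed(seed)
        np.random.shuffle(episode_paths)
        np.random.set_state(random_state)
    else:
        np.random.shuffle(episode_paths)
    if use_all:
        # Rescale the requested sizes in proportion to the number of episodes.
        multiplier = float(len(episode_paths)) / float(n_train + n_valid + n_test)
        n_train = int(math.floor(multiplier * n_train))
        n_valid = int(math.floor(multiplier * n_valid))
        n_test = int(math.floor(multiplier * n_test))

    assert n_train + n_valid + n_test <= len(episode_paths)
    # The test slice starts right after the training and validation slices.
    return (episode_paths[:n_train],
            episode_paths[n_train:n_train + n_valid],
            episode_paths[n_train + n_valid:n_train + n_valid + n_test])
```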
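The method above shuffles the caller's list in place and briefly reseeds the global `np.random` state. As a point of comparison, here is a minimal sketch, assuming NumPy 1.17 or later (the `np.random.default_rng` Generator API), of a standalone variant that draws from a local generator and leaves the input list untouched. The function name `split_episodes_local_rng` is hypothetical and not part of human-rl, and it omits the `use_all` rescaling.

```python
import numpy as np


def split_episodes_local_rng(episode_paths, n_train, n_valid, n_test, seed=None):
    """Hypothetical standalone variant: shuffle a copy with a local Generator."""
    # default_rng(seed) builds an independent generator, so the global
    # np.random state is never read or modified.
    rng = np.random.default_rng(seed)
    # permutation(n) returns a shuffled index array; the input list is not mutated.
    order = rng.permutation(len(episode_paths))
    shuffled = [episode_paths[i] for i in order]
    assert n_train + n_valid + n_test <= len(shuffled)
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:n_train + n_valid + n_test]
    return train, valid, test


paths = ["ep_%02d" % i for i in range(10)]
train, valid, test = split_episodes_local_rng(paths, 6, 2, 2, seed=0)
print(len(train), len(valid), len(test))  # 6 2 2
```

Using a local generator makes the split reproducible for a given seed without the get/seed/set dance around the global state, at the cost of producing a different permutation than `np.random.seed(seed)` would for the same seed value.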
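With `use_all=True`, `n_train`, `n_valid` and `n_test` act as ratios: each is multiplied by `len(episode_paths) / (n_train + n_valid + n_test)` and floored separately. Because each floor drops less than one episode, up to two episodes can end up in no split. The sketch below reproduces that arithmetic; both helper names are hypothetical, and the second one (which gives the leftover episodes to the training set) is only one possible design choice, not something the original code does.

```python
import math


def scale_split_sizes(n_episodes, n_train, n_valid, n_test):
    """Reproduce the use_all scaling: treat the three counts as ratios."""
    multiplier = float(n_episodes) / float(n_train + n_valid + n_test)
    return (int(math.floor(multiplier * n_train)),
            int(math.floor(multiplier * n_valid)),
            int(math.floor(multiplier * n_test)))


def scale_split_sizes_exhaustive(n_episodes, n_train, n_valid, n_test):
    """Hypothetical variant: assign the floor() leftovers to the training set."""
    train, valid, test = scale_split_sizes(n_episodes, n_train, n_valid, n_test)
    train += n_episodes - (train + valid + test)
    return train, valid, test


print(scale_split_sizes(23, 6, 2, 2))             # (13, 4, 4): 2 episodes unused
print(scale_split_sizes_exhaustive(23, 6, 2, 2))  # (15, 4, 4)
```

For 23 episodes and a 6:2:2 ratio the multiplier is 2.3, giving floor(13.8) = 13, floor(4.6) = 4 and floor(4.6) = 4, so 21 episodes are used.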