def split_episodes(self, episode_paths, n_train, n_valid, n_test, seed=None, use_all=True):
    """Split episodes between training, validation and test sets.

    ``episode_paths`` is shuffled IN PLACE and then cut into three
    consecutive, non-overlapping slices (train, then validation, then test).

    Args:
        episode_paths: mutable sequence of episodes; it is shuffled in place.
        n_train: number of training episodes (a relative share when
            ``use_all`` is true).
        n_valid: number of validation episodes (a relative share when
            ``use_all`` is true).
        n_test: number of test episodes (a relative share when ``use_all``
            is true).
        seed: random seed (have split performed consistently every time).
            When given, the global NumPy random state is saved before
            seeding and restored afterwards, so the caller's random stream
            is not disturbed.
        use_all: when true, the three counts are scaled by
            ``len(episode_paths) / (n_train + n_valid + n_test)`` and
            rounded down, so the sets keep their proportions while covering
            as much of ``episode_paths`` as possible.

    Returns:
        A tuple ``(train, valid, test)`` of slices of the shuffled
        ``episode_paths``.

    Raises:
        AssertionError: if the (possibly rescaled) counts add up to more
            than ``len(episode_paths)``.
        ZeroDivisionError: if ``use_all`` is true and the three counts sum
            to zero.
    """
    if seed is not None:
        random_state = np.random.get_state()
        try:
            np.random.seed(seed)
            np.random.shuffle(episode_paths)
        finally:
            # Always hand the global generator back untouched, even if the
            # shuffle fails (e.g. on an immutable sequence).
            np.random.set_state(random_state)
    else:
        np.random.shuffle(episode_paths)

    if use_all:
        multiplier = float(len(episode_paths)) / float(n_train + n_valid + n_test)
        n_train = int(math.floor(multiplier * n_train))
        n_valid = int(math.floor(multiplier * n_valid))
        n_test = int(math.floor(multiplier * n_test))

    assert n_train + n_valid + n_test <= len(episode_paths)

    # The test set starts where the validation set ends; offsetting by n_test
    # instead of n_valid would overlap validation or fall off the end.
    valid_end = n_train + n_valid
    test_end = valid_end + n_test
    return (episode_paths[:n_train],
            episode_paths[n_train:valid_end],
            episode_paths[valid_end:test_end])
评论列表
文章目录