def split_train_val(self, subject_list):
    """Shuffle the subjects with a fixed seed and split them into train/val.

    The fraction ``self.train_size`` of the shuffled subjects goes to the
    training part and the following ``1 - self.train_size`` fraction goes to
    the validation part.  The two sizes are rounded independently, so they
    are not guaranteed to add up to ``len(subject_list)``.

    The input is not modified: a shallow copy is shuffled instead.  Shuffling
    the caller's object in place would make a second call on the same object
    start from a different order and therefore return a different split,
    despite the fixed seed.

    Args:
        subject_list: Mutable sequence of subjects (anything accepted by
            ``np.random.shuffle`` that also supports slicing).

    Returns:
        A ``(train, val)`` tuple of slices of the shuffled copy.  When the
        computed validation size is 0, the whole shuffled copy is returned
        for both parts (the same object twice).
    """
    import copy

    # Work on a shallow copy so the caller's sequence keeps its order.
    shuffled = copy.copy(subject_list)

    # NOTE: this reseeds NumPy's *global* RNG on every call.  It is kept
    # as-is because code elsewhere may depend on that side effect.
    np.random.seed(42)
    np.random.shuffle(shuffled)

    n_total = len(shuffled)
    n_train = int(np.round(self.train_size * n_total))
    # Machine epsilon is subtracted before rounding, nudging the value
    # slightly downwards.
    n_val = int(np.round((1 - self.train_size) * n_total - np.finfo(float).eps))

    if n_val == 0:
        # No subjects left for validation: hand back the full set for both.
        return shuffled, shuffled
    return shuffled[:n_train], shuffled[n_train:n_train + n_val]
评论列表
文章目录