def _get_validation_split(self):
    """Split the training rows into train and validation index arrays.

    Reads ``self.train_csv_file``, multi-hot encodes the space-separated
    ``tags`` column and runs one ``StratifiedShuffleSplit`` per label
    (seeded with the label's column index), holding out
    ``self.validation_split`` of the rows.  A row position is appended to
    the train or validation list only if it is in neither list yet, so
    the two results never overlap and contain no duplicates.

    Returns:
        A ``(trn_index, val_index)`` tuple of NumPy arrays holding row
        positions into the training CSV.
    """
    train = pd.read_csv(self.train_csv_file)
    # Read the tags column by name instead of unpacking ``train.values``
    # rows, so extra columns in the CSV do not break the unpacking.
    tag_strings = train['tags'].values

    # Map every distinct tag to a column index.  The tags are sorted because
    # the iteration order of a set of strings changes between interpreter
    # runs (hash randomisation); unsorted, the label behind each
    # ``random_state`` below -- and hence the split -- is not reproducible.
    labels = sorted({tag for tags in tag_strings for tag in tags.split(' ')})
    label_map = {label: col for col, label in enumerate(labels)}

    # Multi-hot targets: y_train[row, col] == 1 when the row carries the tag.
    y_train = np.zeros((len(train), len(label_map)), dtype=np.uint8)
    for row, tags in enumerate(tag_strings):
        for tag in tags.split(' '):
            y_train[row, label_map[tag]] = 1

    trn_index = []
    val_index = []
    index = np.arange(len(train))
    for col in range(len(label_map)):
        sss = StratifiedShuffleSplit(
            n_splits=2, test_size=self.validation_split, random_state=col
        )
        # Each pass overwrites X_train / X_test, so only the last of the
        # generated splits is used below.
        for train_pos, test_pos in sss.split(index, y_train[:, col]):
            X_train, X_test = index[train_pos], index[test_pos]
        # Keep only positions not yet assigned, to ensure there is no
        # repetition within each split or between the splits.
        # NOTE(review): with train_size left at its default each split
        # appears to cover every row, in which case the first label already
        # assigns all rows and later labels add nothing -- confirm whether
        # per-label stratification is really achieved here.
        trn_index.extend(set(X_train) - set(trn_index) - set(val_index))
        val_index.extend(set(X_test) - set(val_index) - set(trn_index))
    return np.array(trn_index), np.array(val_index)
评论列表
文章目录