def balance_data(data, max_len=70, share_dev=0.05, random_state=0):
    """Split sentences into train/dev sets stratified by word-count bucket.

    Sentences with more than ``max_len`` words are dropped. Words are counted
    by splitting on a single space. The remaining sentences are bucketed by
    word count. A ``StratifiedShuffleSplit`` over those buckets then gives
    both splits a similar length distribution.

    Args:
        data: the sentences to split. The code uses boolean-mask and
            integer-array indexing on it, so it presumably must be a NumPy
            array of strings -- TODO confirm against callers.
        max_len: maximum number of words a sentence may have to be kept.
        share_dev: fraction of the kept sentences placed in the second
            (dev/test) split.
        random_state: seed passed to ``StratifiedShuffleSplit``.

    Returns:
        ``(X_train, X_test)``: two disjoint subsets of the filtered ``data``.
    """
    # Lower edges of the word-count buckets used as stratification labels.
    # Everything at or above the last edge falls into the final bucket.
    bin_edges = np.array([0, 5, 8, 12, 17, 21, 26])

    # Word count per sentence. Computed once and filtered together with data.
    lengths = np.array([len(s.split(' ')) for s in data])
    keep = lengths <= max_len
    data = data[keep]
    lengths = lengths[keep]

    # Bucket index in 0..len(bin_edges)-1. Digitizing against the lower
    # edges only (no max_len top edge) keeps sentences of exactly max_len
    # words in the last bucket. Otherwise they form a separate, usually tiny,
    # class that StratifiedShuffleSplit may reject.
    labels = np.digitize(lengths, bin_edges) - 1

    sss = StratifiedShuffleSplit(n_splits=2, test_size=share_dev,
                                 random_state=random_state)
    # Only the last generated split is used. Keeping n_splits=2 reproduces
    # the split returned by the original loop, which overwrote the first one.
    *_, (train_index, test_index) = sss.split(lengths, labels)
    X_train, X_test = data[train_index], data[test_index]
    return X_train, X_test
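# 评论列表 (comment list) -- leftover web-page navigation text
# 文章目录 (table of contents) -- leftover web-page navigation text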