def get_sample(self, N=600, scale=False, random_state=None):
    """Load the data set and return features/targets, optionally shuffled and split.

    The table returned by ``self.pre_process(self.file_name)`` is converted
    with ``.values``. Column 1 is used as the target ``y`` and columns 2
    onward as the feature matrix ``xs``. Column 0 is not used.

    Args:
        N: Number of rows placed in the training split after shuffling.
            ``-1`` (or ``None``) disables shuffling and splitting and returns
            all rows in their original order.
        scale: If true, standardize the features with ``preprocessing.scale``
            before any shuffling/splitting.
        random_state: Controls the shuffle used when splitting.
            ``None`` (default) uses NumPy's global random state, as before.
            An int seeds a private ``np.random.RandomState`` so the split is
            reproducible. An object exposing ``permutation`` (for example a
            ``np.random.RandomState`` or ``np.random.Generator``) is used
            directly.

    Returns:
        ``(xs_train, xs_test, y_train, y_test)`` when splitting.
        ``(xs, y)`` when ``N`` is ``-1`` or ``None``.
    """
    data = self.pre_process(self.file_name).values
    xs = data[:, 2:]
    y = data[:, 1]
    if scale:
        # NOTE(review): the scaling statistics are computed over ALL rows,
        # including those that end up in the test split (train/test leakage).
        # Kept as-is to preserve existing results; consider fitting the scaler
        # on the training rows only.
        xs = preprocessing.scale(xs)

    if N is None or N == -1:
        return xs, y

    # Resolve the source of randomness; None keeps the historical behavior of
    # drawing from NumPy's global RNG (so np.random.seed still applies).
    if random_state is None:
        rng = np.random
    elif hasattr(random_state, "permutation"):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # Shuffle features and targets with the same permutation so rows stay aligned.
    perm = rng.permutation(xs.shape[0])
    xs = xs[perm]
    y = y[perm]
    # The first N shuffled rows form the training set; the rest form the test set.
    xs_train, xs_test = np.split(xs, [N])
    y_train, y_test = np.split(y, [N])
    return xs_train, xs_test, y_train, y_test
评论列表
文章目录