dataloader.py source code

python

Project: CNN-parallel  Author: harpribot
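# Imports this method relies on (assumed; the listing omits the top of dataloader.py).
# Older scikit-learn releases exposed train_test_split in sklearn.cross_validation.
import numpy as np
import pandas as pd
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

def extract_train_and_validation_data(self, num_labels):
    # load the training CSV (first row is the header) as a NumPy array
    data = pd.read_csv(self.train_data_filename, header=0).values
    # column 0 holds the labels; the remaining columns are the features
    feature_vec = data[:, 1:]
    labels = data[:, 0]

    # min-max scale each sample (row) to [0, 1]; MinMaxScaler works
    # column-wise, so the transposes make it operate per row
    min_max_scaler = preprocessing.MinMaxScaler()
    feature_vec = min_max_scaler.fit_transform(feature_vec.T).T

    # convert labels to one-hot form
    labels_onehot = (np.arange(num_labels) == labels[:, None]).astype(np.float32)

    # split into training (80%) and validation (20%) sets
    self.train_X, self.val_X, self.train_y, self.val_y = train_test_split(
        feature_vec, labels_onehot, test_size=0.2, random_state=42)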
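
The two array manipulations in the method above can be tried on toy data. The following is a minimal sketch, not part of the CNN-parallel project: the helper names `one_hot` and `scale_rows` are hypothetical, and `scale_rows` is a plain NumPy stand-in for the transposed `MinMaxScaler` call, assuming that per-row scaling is what the method intends.

```python
import numpy as np

def one_hot(labels, num_labels):
    """One-hot encode integer labels via broadcasting (hypothetical helper)."""
    labels = np.asarray(labels)
    # (num_labels,) compared with (n, 1) broadcasts to an (n, num_labels) boolean mask
    return (np.arange(num_labels) == labels[:, None]).astype(np.float32)

def scale_rows(features):
    """Min-max scale each row (sample) to [0, 1] (hypothetical helper).

    NumPy stand-in for MinMaxScaler().fit_transform(features.T).T.
    """
    features = np.asarray(features, dtype=np.float64)
    row_min = features.min(axis=1, keepdims=True)
    row_range = features.max(axis=1, keepdims=True) - row_min
    # Constant rows have zero range; divide by 1 so they map to 0 instead of NaN
    row_range[row_range == 0] = 1.0
    return (features - row_min) / row_range

encoded = one_hot([2, 0, 1], num_labels=3)
scaled = scale_rows([[0, 5, 10], [4, 4, 4]])
print(encoded)
print(scaled)
```

Each row of `encoded` has a single 1 at the index of its label, and each row of `scaled` is stretched to its own [0, 1] range, independently of the other samples.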