sampler.py source code

python

Project: bnn-analysis  Author: myshkov
import numpy as np


def _normalise_data(self):
    # Defaults leave the data unchanged: zero mean, unit standard deviation.
    self.train_x_mean = np.zeros(self.input_dim)
    self.train_x_std = np.ones(self.input_dim)

    self.train_y_mean = np.zeros(self.output_dim)
    self.train_y_std = np.ones(self.output_dim)

    if self.normalise_data:
        # Statistics are computed on the training set only.
        self.train_x_mean = np.mean(self.train_x, axis=0)
        self.train_x_std = np.std(self.train_x, axis=0)
        # Constant features have zero std; use 1 to avoid division by zero.
        self.train_x_std[self.train_x_std == 0] = 1.

        # Broadcasting applies the per-feature statistics to every row,
        # so the same training statistics also scale the test inputs.
        x_mean = self.train_x_mean.astype(np.float32)
        x_std = self.train_x_std.astype(np.float32)
        self.train_x = (self.train_x - x_mean) / x_std
        self.test_x = (self.test_x - x_mean) / x_std

        self.train_y_mean = np.mean(self.train_y, axis=0)
        self.train_y_std = np.std(self.train_y, axis=0)

        # The original `if self.train_y_std == 0:` raises ValueError for
        # arrays with more than one element; np.where handles any shape.
        self.train_y_std = np.where(self.train_y_std == 0, 1., self.train_y_std)

        self.train_y = (self.train_y - self.train_y_mean) / self.train_y_std
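
The method above depends on attributes of a class that this page does not show. As a rough standalone sketch of the same idea (standardising with training-set statistics and guarding against zero standard deviation), something like the following could be used. The function name `normalise` and the sample arrays are hypothetical and are not part of bnn-analysis.

```python
import numpy as np


def normalise(train, test):
    """Standardise both arrays with statistics taken from `train` only.

    Hypothetical helper for illustration; not part of bnn-analysis.
    """
    mean = np.mean(train, axis=0)
    std = np.std(train, axis=0)
    # Constant columns have zero spread; divide by 1 so they map to 0.
    std = np.where(std == 0, 1.0, std)
    return (train - mean) / std, (test - mean) / std, mean, std


# Made-up sample data: the second column is constant in the training set.
train_x = np.array([[1.0, 5.0], [3.0, 5.0]])
test_x = np.array([[2.0, 7.0]])

train_n, test_n, mean, std = normalise(train_x, test_x)
```

Here the constant second column ends up with a divisor of 1 rather than 0, so the training values in that column become 0 and the test value is only shifted by the training mean.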