synthetic_env.py source code

python

Project: bnn-analysis  Author: myshkov
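# Module-level imports this method relies on (not shown in the excerpt).
import numpy as np
import scipy as sp
import scipy.stats


def create_training_test_sets(self):
    # Training set: inputs drawn from a standard normal truncated to [-2, 2]
    # and scaled by a quarter of the data interval width, so every sample
    # lies within half an interval width of zero.
    scale = self.data_interval_right - self.data_interval_left
    train_x = sp.stats.truncnorm.rvs(-2, 2, scale=0.25 * scale, size=self.data_size).astype(np.float32)
    train_x = np.sort(train_x)
    # Targets are the true function plus Gaussian noise (std 0.2).
    train_y = self.true_f(train_x) + 0.2 * np.random.randn(self.data_size)

    # Store as single-element lists of column vectors with shape (n, 1).
    self.train_x = [train_x.reshape((train_x.shape[0], 1))]
    self.train_y = [train_y.reshape((train_y.shape[0], 1))]

    # Test set (disabled alternative: sample test inputs like the training inputs)
    # scale = self.test_data_interval_right - self.test_data_interval_left
    # test_x = sp.stats.truncnorm.rvs(-2, 2, scale=0.25 * scale, size=self.test_data_size).astype(np.float32)
    # test_x = np.sort(test_x)
    # test_y = self.true_f(test_x)

    # Test set in use: a noise-free grid over the viewing range, step 0.01.
    self.test_x = np.arange(self.view_xrange[0], self.view_xrange[1], 0.01, dtype=np.float32)
    self.test_y = self.true_f(self.test_x)

    self.test_x = [self.test_x.reshape((self.test_x.shape[0], 1))]
    self.test_y = [self.test_y.reshape((self.test_y.shape[0], 1))]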
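
The method above depends on attributes of its class that the excerpt does not show. The following is a minimal standalone sketch of the same sampling steps, not the project's actual class. The function name `make_sets`, the default interval, sizes and seed, and the sine-based `true_f` are all assumptions made for illustration. It returns plain column arrays rather than the single-element lists the method stores.

```python
import numpy as np
from scipy import stats


def true_f(x):
    # Hypothetical target function; the project's real true_f is not shown.
    return np.sin(3.0 * x)


def make_sets(left=-1.0, right=1.0, data_size=50, view_xrange=(-2.0, 2.0), seed=0):
    # All parameter defaults here are illustrative assumptions.
    rng = np.random.RandomState(seed)
    scale = right - left

    # Standard normal truncated to [-2, 2], then scaled by 0.25 * scale:
    # samples fall in [-0.5 * scale, 0.5 * scale], centered on zero.
    train_x = stats.truncnorm.rvs(
        -2, 2, scale=0.25 * scale, size=data_size, random_state=rng
    ).astype(np.float32)
    train_x = np.sort(train_x)

    # Noisy targets: true function plus Gaussian noise with std 0.2.
    train_y = true_f(train_x) + 0.2 * rng.randn(data_size)

    # Noise-free test grid over the viewing range with step 0.01.
    test_x = np.arange(view_xrange[0], view_xrange[1], 0.01, dtype=np.float32)
    test_y = true_f(test_x)

    # Reshape everything into column vectors of shape (n, 1).
    return (
        train_x.reshape(-1, 1),
        train_y.reshape(-1, 1),
        test_x.reshape(-1, 1),
        test_y.reshape(-1, 1),
    )


train_x, train_y, test_x, test_y = make_sets()
print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)
```

Because the truncated normal is centered on zero, the training inputs cover the data interval only when that interval is itself symmetric about zero, as it is with the defaults assumed here.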