datasets.py source code

python

Project: importance-sampling  Author: idiap
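import numpy as np


def load_data(self):
    # Create the data using magic numbers to approximate the figure in
    # canevet_icml2016
    x = np.linspace(0, 1, self.N).astype(np.float32)
    ones = np.ones_like(x).astype(int)
    boundary = np.sin(4*(x + 0.5)**5)/3 + 0.5

    # data[i, j] holds (feature 0, feature 1, label) for grid point (i, j)
    data = np.empty(shape=[self.N, self.N, 3], dtype=np.float32)
    data[:, :, 0] = 1-x
    for i in range(self.N):
        data[i, :, 1] = 1-x[i]
        # Probability of label 1: a sigmoid that falls from 1 to 0 as x
        # crosses boundary[i]; self.smooth controls its steepness
        data[i, :, 2] = 1 / (1 + np.exp(self.smooth*(x - boundary[i])))
        # Replace each probability with one Bernoulli draw (0 or 1)
        data[i, :, 2] = np.random.binomial(ones, data[i, :, 2])
    data = data.reshape(-1, 3)
    np.random.shuffle(data)

    # Create train and test arrays. Indexing from the front avoids the
    # empty training set that data[:-split] yields when split == 0.
    split = int(len(data)*self.test_split)
    n_train = len(data) - split
    X_train = data[:n_train, :2]
    y_train = data[:n_train, 2]
    X_test = data[n_train:, :2]
    y_test = data[n_train:, 2]

    return (X_train, y_train), (X_test, y_test)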
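
The method above depends on a surrounding class that supplies self.N, self.smooth and self.test_split, so it cannot be run on its own. The following is a minimal standalone sketch of the same construction, vectorized with broadcasting instead of the per-row loop. The function name make_boundary_dataset, the default argument values and the seed parameter are assumptions made for illustration and do not come from the importance-sampling project.

```python
import numpy as np


def make_boundary_dataset(n=32, smooth=40.0, test_split=0.2, seed=0):
    """Hypothetical standalone variant of load_data (names and defaults assumed)."""
    rng = np.random.default_rng(seed)
    x = np.linspace(0, 1, n).astype(np.float32)
    boundary = np.sin(4 * (x + 0.5) ** 5) / 3 + 0.5

    # f0[i, j] = 1 - x[j] and f1[i, j] = 1 - x[i], matching the loop version
    f0, f1 = np.meshgrid(1 - x, 1 - x)

    # Probability of label 1 at grid point (i, j): sigmoid around boundary[i]
    p = 1 / (1 + np.exp(smooth * (x[None, :] - boundary[:, None])))
    # One Bernoulli draw per grid point
    labels = rng.binomial(1, p).astype(np.float32)

    # Flatten the n x n grid into rows of (feature 0, feature 1, label)
    data = np.stack([f0, f1, labels], axis=-1).reshape(-1, 3)
    rng.shuffle(data)

    # Hold out the last rows as the test set
    n_test = int(len(data) * test_split)
    n_train = len(data) - n_test
    train = (data[:n_train, :2], data[:n_train, 2])
    test = (data[n_train:, :2], data[n_train:, 2])
    return train, test


(X_train, y_train), (X_test, y_test) = make_boundary_dataset()
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
```

Passing a seed makes the sampled labels and the shuffle reproducible, which the original method does not attempt because it uses the global np.random state.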