get_data.py source file

python

Project: Doubly-Stochastic-DGP  Author: ICL-SML
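import os

import numpy as np
import pandas

# data_path, download and make_split are defined elsewhere in get_data.py


def get_regression_data(name, split, data_path=data_path):
    """Load a regression dataset by name and return standardized train/test arrays."""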
    path = '{}{}.csv'.format(data_path, name)

    # fetch the CSV first if it is not already present locally
    if not os.path.isfile(path):
        download(name + '.csv', data_path=data_path)

    data = pandas.read_csv(path, header=None).values

    if name in ['energy', 'naval']:
        # there are two Ys for these, but take only the first
        X_full = data[:, :-2]
        Y_full = data[:, -2]
    else:
        X_full = data[:, :-1]
        Y_full = data[:, -1]


    # make_split (defined elsewhere in get_data.py) yields train (X, Y) and test (Xs, Ys) arrays
    X, Y, Xs, Ys = make_split(X_full, Y_full, split)

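    # standardize inputs per feature using training-set statistics only;
    # the 1e-6 term guards against division by zero for constant columns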
    X_mean, X_std = np.average(X, 0), np.std(X, 0)+1e-6

    X = (X - X_mean)/X_std
    Xs = (Xs - X_mean)/X_std

    # return targets as column vectors of shape (N, 1)
    return X, Y[:, None], Xs, Ys[:, None]
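
The split and standardization steps above can be tried without downloading any CSV. The following is a minimal sketch on synthetic data, not the repository's code: `make_split_sketch` is a hypothetical stand-in for `make_split` (whose real implementation is not shown here), and both the 10% test fraction and the use of `split` as a random seed are assumptions. `standardize` mirrors the mean and std lines of `get_regression_data`.

```python
import numpy as np


def make_split_sketch(X_full, Y_full, split, test_fraction=0.1):
    """Hypothetical stand-in for make_split: seeded shuffle, then hold out a test set."""
    rng = np.random.RandomState(split)  # assumption: the split index acts as the seed
    perm = rng.permutation(X_full.shape[0])
    n_test = max(1, int(round(test_fraction * X_full.shape[0])))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X_full[train_idx], Y_full[train_idx], X_full[test_idx], Y_full[test_idx]


def standardize(X, Xs, eps=1e-6):
    """Scale train and test inputs with the training mean and std only."""
    X_mean, X_std = np.average(X, 0), np.std(X, 0) + eps
    return (X - X_mean) / X_std, (Xs - X_mean) / X_std


# Synthetic array standing in for a CSV file: 50 rows, 3 features, 1 target column.
data = np.random.RandomState(0).randn(50, 4)
X, Y, Xs, Ys = make_split_sketch(data[:, :-1], data[:, -1], split=0)
X, Xs = standardize(X, Xs)
Y, Ys = Y[:, None], Ys[:, None]  # column vectors, matching the return value above
print(X.shape, Y.shape, Xs.shape, Ys.shape)
```

Because the statistics come from the training rows alone, the training inputs end up with zero mean and roughly unit variance per feature, while the test inputs are only approximately standardized. This keeps test-set information out of the preprocessing.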