baseline.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Steal-ML 作者: ftramer 项目源码 文件源码
def run(dataset_name, n_features):
    """Run the baseline experiment for one dataset and print the results.

    The model file (``train.scale.model``) and the test set (``test.scale``)
    are read from ``../targets/<dataset_name>/`` relative to the current
    working directory.

    The experiment is repeated ``n_repeat`` (10) times.  In every round a
    fresh ``LibSVMOnline`` extractor is created and, for each index in
    ``result.index``, it is asked to collect queries up to a budget of
    ``result.Q_by_U[i] * (n_features + 1)``.  The points exposed by the
    extractor are appended to a training set that keeps growing within the
    round, a ``Baseline`` is run on it, and the two values it returns plus
    the extractor's query count are recorded in ``result``.

    Args:
        dataset_name: Name of the sub-directory of ``../targets`` holding
            the model and test files.
        n_features: Number of features; forwarded to the extractor, the
            svmlight loader and ``Baseline``, and used to scale the budget.

    Returns:
        None.  The accumulated ``Result`` object is printed.
    """
    base_dir = os.path.join(os.getcwd(), '../targets/%s/' % dataset_name)
    model_file = os.path.join(base_dir, 'train.scale.model')
    test_file = os.path.join(base_dir, 'test.scale')

    result = Result('baseline')
    n_repeat = 10
    for repeat in range(n_repeat):
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('Round %d of %d' % (repeat, n_repeat - 1))

        # load model and collect QSV
        ex = LibSVMOnline(dataset_name, model_file, (1, -1), n_features,
                          'uniform', 1e-1)

        # Load the test set used for scoring.  n_features is passed by
        # keyword because newer scikit-learn releases no longer accept it
        # positionally.
        # NOTE(review): this load looks loop-invariant; it is kept inside
        # the loop in case Baseline modifies the arrays -- confirm before
        # hoisting.
        X_test, y_test = load_svmlight_file(test_file, n_features=n_features)
        X_test = X_test.todense()

        # The training set is shared by all budgets of this round and only
        # ever grows.
        train_x, train_y = [], []
        for i in result.index:
            q_by_u = result.Q_by_U[i]
            ex.collect_up_to_budget(q_by_u * (n_features + 1))
            # presumably points close to the decision boundary and their
            # labels -- verify against LibSVMOnline
            train_x.extend(ex.pts_near_b)
            train_y.extend(ex.pts_near_b_labels)

            base = Baseline(ex.batch_predict, (train_x, train_y),
                            (X_test, y_test), n_features)
            L_unif, L_test = base.do()

            result.L_unif[i].append(L_unif)
            result.L_test[i].append(L_test)
            result.nquery[i].append(ex.get_n_query())

    print(result)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号