# active_learning.py -- source file
#
# Language: Python
# Views: 34  Bookmarks: 0  Likes: 0  Comments: 0
#
# Project: Steal-ML  Author: ftramer  (project source / file source)
def _grid_search_svc(x, y, param_grid):
    """Grid-search an ``svm.SVC`` over ``param_grid`` and return the best model.

    Cross-validation uses ``StratifiedShuffleSplit`` with 5 iterations, a 20%
    test fraction and a fixed seed (42). The labels are passed to the splitter's
    constructor, which is the older scikit-learn ``cross_validation`` signature.
    The returned estimator is ``GridSearchCV.best_estimator_`` fit on (x, y).
    """
    cv = StratifiedShuffleSplit(y, n_iter=5, test_size=0.2, random_state=42)
    grid = GridSearchCV(svm.SVC(), param_grid=param_grid, cv=cv, verbose=0, n_jobs=-1)
    grid.fit(x, y)
    return grid.best_estimator_


def CAL(name, label_p, label_n, oracle, n_features, ftype, test_x, test_y):
    """Run a CAL-style active-learning loop that extracts an SVM from ``oracle``.

    (CAL presumably stands for Cohn-Atlas-Ladner selective sampling -- confirm.)

    An initial set of 100 points is collected through ``OnlineBase``. An
    ``svm.SVC`` is then fit by grid search over C in [1e-2, 1e5] and gamma in
    [1e-5, 1e1], with 10 log-spaced values each.

    Each round wraps the current model's ``predict`` in a fresh ``OnlineBase``
    and asks it for a pair of points. For every new point the model is refit
    twice: once with label 1 and once with label -1.

    * Both refits reach >= 99% training accuracy: the label is ambiguous, so
      ``oracle`` is queried for that point.
    * Exactly one refit does: that label is inferred without a query.
    * Neither does: the point is discarded.

    After every round, the running query count ``q`` and the accuracy on
    (test_x, test_y) are printed. The loop ends once ``q`` reaches 3500.

    Args:
        name, label_p, label_n, n_features, ftype: forwarded to ``OnlineBase``;
            their meaning is defined there.
        oracle: query function handed to ``OnlineBase``. For ambiguous points it
            is also called as ``oracle([point])``, and element 0 of the result is
            used as the label.
        test_x, test_y: held-out data used only for the printed accuracy.

    Returns:
        None. Progress is reported via ``print``.
    """
    online = OnlineBase(name, label_p, label_n, oracle, n_features, ftype, error=.5)

    # NOTE(review): the query count is read *before* collect_pts() below, so the
    # initial 100 points may not be charged against the budget -- confirm
    # against OnlineBase.get_n_query().
    q = online.get_n_query()
    C_range = np.logspace(-2, 5, 10, base=10)
    gamma_range = np.logspace(-5, 1, 10, base=10)
    param_grid = dict(gamma=gamma_range, C=C_range)
    # Minimum training accuracy a refit must reach for a label to be "consistent".
    consistent = .99

    x, y = online.collect_pts(100, -1)
    h_ = _grid_search_svc(x, y, param_grid)

    while q < 3500:
        # Surrogate learner whose queries go to the current model, not the oracle.
        online_ = OnlineBase('', label_p, label_n, h_.predict, n_features, ftype, error=.1)
        x_ = online_.collect_one_pair()
        if x_ is not None and len(x_) > 0:
            for _x in x_:
                # NOTE(review): labels are hard-coded as 1 / -1 here although
                # label_p / label_n are parameters -- confirm they agree with the
                # labels returned by collect_pts().
                x.append(_x)
                y.append(1)
                h_pos = _grid_search_svc(x, y, param_grid)
                s_pos = sm.accuracy_score(y, h_pos.predict(x))

                y[-1] = -1
                h_neg = _grid_search_svc(x, y, param_grid)
                s_neg = sm.accuracy_score(y, h_neg.predict(x))

                if s_pos >= consistent and s_neg >= consistent:
                    # Ambiguous: ask the oracle about *this* point only (the old
                    # code queried the whole pair and always took label [0]).
                    # NOTE(review): this oracle query is not added to q.
                    print('branch 1')
                    y[-1] = oracle([_x])[0]
                elif s_pos >= consistent:
                    print('branch 2')
                    y[-1] = 1
                elif s_neg >= consistent:
                    print('branch 3')
                    y[-1] = -1
                else:
                    # Neither label fits: drop the point and keep the previous model.
                    print('branch 4:  %s %s' % (s_pos, s_neg))
                    del x[-1]
                    del y[-1]
                    continue

                # Adopt the model trained with the label that was actually kept.
                # Doing this per accepted point (rather than once after the loop)
                # avoids installing a model trained on a point that was discarded.
                h_ = h_pos if y[-1] == 1 else h_neg

        q += online_.get_n_query()
        pred_y = h_.predict(test_x)
        print('%s %s' % (q, sm.accuracy_score(test_y, pred_y)))
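# Comments
# Table of contents
#
#
# Questions
#
#
# Interview experiences
#
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account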