def CAL_v(name, label_p, label_n, oracle, n_features, ftype, test_x, test_y,
          budget=3500):
    """Iteratively fit an SVC surrogate of ``oracle``, growing its training set.

    Presumably a CAL-style active-learning baseline -- TODO confirm.

    1. Collect an initial labelled set through an ``OnlineBase`` that wraps
       the real ``oracle``.
    2. Each round:
       - grid-search an ``svm.SVC`` over C and gamma on every point
         gathered so far;
       - use the best estimator's ``predict`` as the oracle of a second
         ``OnlineBase`` to collect new points;
       - label those points with the real ``oracle`` and append them;
       - print progress.

    Args:
        name: forwarded to the initial ``OnlineBase`` only (the per-round
            instance gets ``''``).
        label_p, label_n: forwarded to both ``OnlineBase`` instances.
        oracle: callable returning labels for a list of points; used for the
            initial collection and to label each round's new points.
        n_features, ftype: forwarded to both ``OnlineBase`` instances.
        test_x, test_y: held-out data, used only for the printed accuracy.
        budget: the loop runs while the query counter is below this value.
            It defaults to 3500, the previously hard-coded limit.

    Returns:
        None. After every round it prints
        ``<len(x)> <query count> <accuracy on (test_x, test_y)>``.
    """
    online = OnlineBase(name, label_p, label_n, oracle, n_features, ftype,
                        error=.5)
    # x and y are grown in place with .extend() below, so collect_pts is
    # assumed to return mutable lists -- confirm against OnlineBase.
    x, y = online.collect_pts(100, -1)
    q = online.get_n_query()

    # Log-spaced SVC hyper-parameter grid; it does not change between rounds.
    C_range = np.logspace(-2, 5, 10, base=10)
    gamma_range = np.logspace(-5, 1, 10, base=10)
    param_grid = dict(gamma=gamma_range, C=C_range)

    while q < budget:
        # Legacy sklearn.cross_validation signature (labels first, n_iter).
        # It is built from the current labels y, so it is recreated each
        # round as y grows.
        cv = StratifiedShuffleSplit(y, n_iter=5, test_size=0.2,
                                    random_state=42)
        grid = GridSearchCV(svm.SVC(), param_grid=param_grid, cv=cv,
                            verbose=0, n_jobs=-1)
        grid.fit(x, y)
        h_ = grid.best_estimator_

        # The current surrogate's predict() is the oracle used to choose new
        # points. Those points are then labelled by the real oracle.
        online_ = OnlineBase('', label_p, label_n, h_.predict, n_features,
                             ftype, error=.1)
        x_, _ = online_.collect_pts(10, 200)
        if x_ is not None and len(x_) > 0:
            x.extend(x_)
            y.extend(oracle(x_))
        # NOTE(review): q grows by online_'s query count. online_ wraps the
        # surrogate h_.predict, so these are not the len(x_) real-oracle
        # labels requested above -- confirm this is the intended cost model.
        # If a round reports zero queries, q stops advancing and the loop
        # repeats.
        q += online_.get_n_query()

        pred_y = h_.predict(test_x)
        # Pass one pre-joined string so the output is identical under the
        # Python 2 print statement and the Python 3 print() function.
        print(' '.join(str(v) for v in
                       (len(x), q, sm.accuracy_score(test_y, pred_y))))
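# NOTE: stray web-page labels, kept as comments so the file still parses:
# 评论列表 ("comment list")
# 文章目录 ("table of contents")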