def _fit_best_svc(x, y, param_grid):
    """Grid-search ``svm.SVC`` over ``param_grid`` and return the best estimator.

    A fresh ``StratifiedShuffleSplit`` (5 iterations, 20% held out, fixed
    seed 42) is built from the current ``y`` on every call, because the
    caller mutates the label list between fits.
    """
    cv = StratifiedShuffleSplit(y, n_iter=5, test_size=0.2, random_state=42)
    grid = GridSearchCV(svm.SVC(), param_grid=param_grid, cv=cv, verbose=0, n_jobs=-1)
    grid.fit(x, y)
    return grid.best_estimator_


def CAL(name, label_p, label_n, oracle, n_features, ftype, test_x, test_y,
        max_queries=3500, min_accuracy=.99):
    """Grow a labelled set point by point and refit an SVC until a query budget is hit.

    A seed set is collected through ``OnlineBase`` and an initial SVC is
    chosen by grid search. Each round then asks a second ``OnlineBase``,
    whose oracle is the current model's ``predict``, for one pair of
    points. Every point of the pair is tentatively labelled ``1`` and then
    ``-1``; a model is grid-searched for each labelling and scored on the
    training set itself. The label is then decided as follows:

    * both labellings reach ``min_accuracy``: ask ``oracle`` for the label;
    * only one labelling reaches it: keep that label without asking;
    * neither reaches it: discard the point.

    After every round the accumulated query count and the accuracy of the
    current model on ``test_x`` / ``test_y`` are printed.

    Args:
        name, label_p, label_n, n_features, ftype: forwarded unchanged to
            ``OnlineBase``; their meaning is defined there.
        oracle: callable used as the ``OnlineBase`` oracle for the seed set
            and called directly as ``oracle(x_)`` when both labellings fit.
        test_x, test_y: evaluation data, used only for the progress print.
        max_queries: the loop runs while the accumulated count is below
            this value (default 3500, the previously hard-coded budget).
        min_accuracy: training accuracy a tentative labelling must reach to
            count as consistent (default .99, the previously hard-coded
            threshold).

    Returns:
        None. Results are only reported on stdout.
    """
    online = OnlineBase(name, label_p, label_n, oracle, n_features, ftype, error=.5)
    # NOTE(review): the counter is read *before* the seed points are
    # collected, so if get_n_query() is a running total the seed queries are
    # not part of the budget. Kept as in the original; confirm it is intended.
    q = online.get_n_query()
    param_grid = dict(
        gamma=np.logspace(-5, 1, 10, base=10),
        C=np.logspace(-2, 5, 10, base=10),
    )
    x, y = online.collect_pts(100, -1)
    h_ = _fit_best_svc(x, y, param_grid)

    while q < max_queries:
        # Probe the current model (not the real oracle) for a pair of points.
        online_ = OnlineBase('', label_p, label_n, h_.predict, n_features, ftype, error=.1)
        x_ = online_.collect_one_pair()
        if x_ is not None and len(x_) > 0:
            for _x in x_:
                # Hypothesis 1: the new point is positive.
                x.append(_x)
                y.append(1)
                h1 = _fit_best_svc(x, y, param_grid)
                s1 = sm.accuracy_score(y, h1.predict(x))

                # Hypothesis 2: the same point is negative.
                y[-1] = -1
                h2 = _fit_best_svc(x, y, param_grid)
                s2 = sm.accuracy_score(y, h2.predict(x))

                fits_pos = s1 >= min_accuracy
                fits_neg = s2 >= min_accuracy
                if fits_pos and fits_neg:
                    # Both labellings are consistent, so the real label is needed.
                    print('branch 1')
                    # NOTE(review): this takes the first label of the whole
                    # pair for every point of the pair, and the call is not
                    # added to q. Kept as in the original; confirm against
                    # the oracle's contract.
                    y[-1] = oracle(x_)[0]
                elif fits_pos:
                    print('branch 2')
                    y[-1] = 1
                elif fits_neg:
                    print('branch 3')
                    y[-1] = -1
                else:
                    # Neither labelling can be fitted: drop the point again.
                    # The format string reproduces the output of the former
                    # statement ``print 'branch 4: ', s1, s2`` exactly.
                    print('branch 4:  %s %s' % (s1, s2))
                    del x[-1]
                    del y[-1]
                    continue

                # Continue with the model that was trained on the kept label.
                if y[-1] == 1:
                    h_ = h1
                else:
                    h_ = h2

        # Only queries made through the probing OnlineBase are accumulated.
        q += online_.get_n_query()
        pred_y = h_.predict(test_x)
        print('%s %s' % (q, sm.accuracy_score(test_y, pred_y)))
评论列表
文章目录