def _fit_owl_fista(X, y, w, loss, max_iter=500, max_linesearch=20, eta=2.0,
                   tol=1e-3, verbose=0):
    """Fit linear coefficients with FISTA using an OWL proximal step.

    The smooth part of the objective is ``loss(y, X @ coef)`` for the
    caller-supplied ``loss`` callable (it is NOT necessarily least
    squares). The non-smooth part is handled by ``prox_owl`` with weights
    ``w``. ``prox_owl`` is presumably the proximal operator of the ordered
    weighted L1 norm; confirm against its definition.

    Parameters
    ----------
    X : matrix
        Design matrix. It is multiplied only via ``safe_sparse_dot``, so it
        may presumably be dense or sparse. Its shape is assumed to be
        (n_samples, n_features); TODO confirm.
    y : object
        Targets, passed unchanged to ``loss``.
    w : array-like
        OWL weights. They are divided by the step constant ``L`` before
        each proximal step, so ``w / L`` must be valid.
    loss : callable
        ``loss(y, scores)`` returns the objective value.
        ``loss(y, scores, return_derivative=True)`` returns
        ``(objective, d objective / d scores)``.
    max_iter, max_linesearch, eta, tol, verbose
        Forwarded positionally to ``fista``.

    Returns
    -------
    Whatever ``fista`` returns. This is presumably the fitted coefficient
    vector of length ``X.shape[1]``; confirm against ``fista``.
    """

    def sfunc(coef, grad=False):
        # Smooth term: the generic `loss` evaluated on the linear scores.
        # The keyword is still named `grad` because fista may pass it by
        # name. The gradient array gets its own name so the flag is not
        # shadowed.
        y_scores = safe_sparse_dot(X, coef)
        if not grad:
            return loss(y, y_scores)
        obj, lp = loss(y, y_scores, return_derivative=True)
        # Chain rule: d loss / d coef = X^T (d loss / d scores).
        gradient = safe_sparse_dot(X.T, lp)
        return obj, gradient

    def nsfunc(coef, L):
        # Proximal step for the OWL term. Scaling the weights by 1 / L
        # presumably corresponds to a step size of 1 / L chosen by fista's
        # line search; confirm in fista.
        return prox_owl(coef, w / L)

    # Start from the all-zeros coefficient vector, one entry per column of X.
    coef = np.zeros(X.shape[1])
    return fista(sfunc, nsfunc, coef, max_iter, max_linesearch,
                 eta, tol, verbose)
# Comment list
# Table of contents