def line_search(Y, signs, direction, current_loss, ls_tries):
    """Backtracking line search along ``direction`` starting from ``Y``.

    Candidate points are ``np.dot(expm(alpha * direction), Y)`` for
    ``alpha = 1, 1/2, 1/4, ...``.  The first candidate whose loss is
    strictly below ``current_loss`` is accepted.

    Parameters
    ----------
    Y : array
        Starting point. It is left-multiplied by the exponential of the
        scaled direction.
    signs :
        Passed through unchanged to ``loss``.
    direction : array
        Search direction. ``alpha * direction`` is passed to ``expm``
        (presumably a square matrix -- TODO confirm against callers).
    current_loss : float or None
        Loss at ``Y``. If None, it is computed as ``loss(Y, signs)``.
    ls_tries : int
        Maximum number of step sizes to try. The step is halved after
        each rejected candidate.

    Returns
    -------
    success : bool
        True if a step that decreases the loss was found.
    Y_new : array
        On success, the accepted point. On failure, the last candidate
        tried (``Y`` itself when ``ls_tries <= 0``).
    new_loss : float
        Loss at ``Y_new`` (``current_loss`` when ``ls_tries <= 0``).
    alpha : float
        On success, the step size that produced ``Y_new``. On failure,
        the step size after the final halving, i.e. half the step that
        produced ``Y_new`` (``1.0`` when ``ls_tries <= 0``).
    """
    alpha = 1.
    if current_loss is None:
        current_loss = loss(Y, signs)

    if ls_tries <= 0:
        # No candidate is evaluated, so Y_new/new_loss would never be
        # assigned. Report "no step taken" instead of raising
        # UnboundLocalError.
        return False, Y, current_loss, alpha

    for _ in range(ls_tries):
        Y_new = np.dot(expm(alpha * direction), Y)
        new_loss = loss(Y_new, signs)
        # Strict decrease required; a NaN loss compares False and is
        # rejected.
        if new_loss < current_loss:
            return True, Y_new, new_loss, alpha
        alpha /= 2.

    # Every candidate was rejected.
    # NOTE(review): alpha has already been halved past the step that
    # produced Y_new. This is kept as-is for compatibility with existing
    # callers; confirm whether they expect that.
    return False, Y_new, new_loss, alpha
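# 评论列表 ("comment list") and 文章目录 ("table of contents"):
# web-page navigation labels left over from scraping; not part of the code.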