def __call__(self, params, params_args, obj, idxs, alpha, prop_mode):
    """Evaluate ``obj``'s objective and flat gradient at a flat parameter vector.

    The flat vector ``params`` is unflattened with ``params_args``, passed to
    ``obj.objective_function``, and the returned gradient dict is flattened
    back into a single array. Non-finite gradient entries are replaced with
    zeros so the caller always receives a usable gradient.

    NOTE(review): the ``(value, gradient)`` return shape suggests this is used
    as an optimizer objective with a combined gradient (e.g. ``jac=True``);
    confirm against the caller.

    Args:
        params: flat parameter vector; assumed to be a NumPy array (TODO
            confirm).
        params_args: layout information passed to ``unflatten_dict`` to
            rebuild the parameter dict from ``params``.
        obj: object exposing ``objective_function(params_dict, idxs,
            alpha=..., prop_mode=...)`` that returns ``(f, grad_dict)``.
        idxs: forwarded unchanged to ``obj.objective_function``.
        alpha: forwarded unchanged to ``obj.objective_function``.
        prop_mode: forwarded unchanged to ``obj.objective_function``.

    Returns:
        ``(f, g)``, where ``f`` is the objective value exactly as returned by
        ``obj`` and ``g`` is the flattened gradient. Any inf/nan entries of
        ``g`` are set to ``0.``.

    Side effects:
        When every gradient entry is finite, a copy of ``params`` is stored in
        ``self.previous_x``, so it tracks the most recent point with a finite
        gradient. Otherwise a warning is printed to stdout and
        ``self.previous_x`` is left unchanged.
    """
    params_dict = unflatten_dict(params, params_args)
    f, grad_dict = obj.objective_function(
        params_dict, idxs, alpha=alpha, prop_mode=prop_mode)
    g, _ = flatten_dict(grad_dict)

    g_is_fin = np.isfinite(g)
    if np.all(g_is_fin):
        # Store a snapshot, not a reference. If the caller reuses or mutates
        # the buffer it passed as ``params``, an aliased reference would
        # silently change and no longer be the last finite-gradient point.
        self.previous_x = np.copy(params)
        return f, g

    print("Warning: inf or nan in gradient: replacing with zeros")
    # Zero out only the non-finite entries; finite entries pass through.
    return f, np.where(g_is_fin, g, 0.)
# Free-function variant of the wrapper above, left commented out for
# reference (takes no ``prop_mode`` argument and does not track previous_x):
#
# def objective_wrapper(params, params_args, obj, idxs, alpha):
#     params_dict = unflatten_dict(params, params_args)
#     f, grad_dict = obj.objective_function(
#         params_dict, idxs, alpha=alpha)
#     g, _ = flatten_dict(grad_dict)
#     g_is_fin = np.isfinite(g)
#     if np.all(g_is_fin):
#         return f, g
#     else:
#         print("Warning: inf or nan in gradient: replacing with zeros")
#         return f, np.where(g_is_fin, g, 0.)
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")