utils.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: geepee 作者: thangbui 项目源码 文件源码
def __call__(self, params, params_args, obj, idxs, alpha, prop_mode):
    """Evaluate ``obj``'s objective and flattened gradient at ``params``.

    ``params`` is turned back into a parameter dict with
    ``unflatten_dict(params, params_args)`` (presumably the inverse of
    ``flatten_dict`` -- confirm against its definition). That dict is passed
    to ``obj.objective_function`` along with ``idxs``, ``alpha`` and
    ``prop_mode``.

    Returns:
        A ``(f, g)`` pair. ``f`` is the objective value exactly as returned
        by ``obj.objective_function``. ``g`` is the gradient dict flattened
        with ``flatten_dict``.

    If any gradient entry is inf or nan, a warning is printed and those
    entries are replaced with ``0.``. In that case ``self.previous_x`` is
    left unchanged. It is only updated to ``params`` when every gradient
    entry is finite.
    """
    param_dict = unflatten_dict(params, params_args)
    value, grads = obj.objective_function(
        param_dict, idxs, alpha=alpha, prop_mode=prop_mode)
    flat_grad, _ = flatten_dict(grads)
    finite_mask = np.isfinite(flat_grad)

    # Guard clause: sanitise a non-finite gradient so the optimiser keeps
    # going, but do not record these params as the last good point.
    if not np.all(finite_mask):
        print("Warning: inf or nan in gradient: replacing with zeros")
        return value, np.where(finite_mask, flat_grad, 0.)

    self.previous_x = params
    return value, flat_grad

# def objective_wrapper(params, params_args, obj, idxs, alpha):
#     params_dict = unflatten_dict(params, params_args)
#     f, grad_dict = obj.objective_function(
#         params_dict, idxs, alpha=alpha)
#     g, _ = flatten_dict(grad_dict)
#     g_is_fin = np.isfinite(g)
#     if np.all(g_is_fin):
#         return f, g
#     else:
#         print("Warning: inf or nan in gradient: replacing with zeros")
#         return f, np.where(g_is_fin, g, 0.)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号