lbfgs_optimizer.py source code

python

Project: third_person_im  Author: bstadie
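import time

import scipy.optimize


def optimize(self, inputs, extra_inputs=None):
    # fmin_l_bfgs_b is called below without fprime or approx_grad,
    # so f_opt must return the pair (loss, gradient).
    f_opt = self._opt_fun["f_opt"]

    if extra_inputs is None:
        extra_inputs = list()

    def f_opt_wrapper(flat_params):
        # Load the candidate parameters into the target before evaluating.
        # Note: extra_inputs is not forwarded to f_opt here; it is only
        # passed to f_loss in the callback.
        self._target.set_param_values(flat_params, trainable=True)
        return f_opt(*inputs)

    # A one-element list lets the nested callback mutate the counter.
    itr = [0]
    start_time = time.time()

    if self._callback:
        def opt_callback(params):
            # Called by SciPy once per iteration with the current parameters.
            loss = self._opt_fun["f_loss"](*(inputs + extra_inputs))
            elapsed = time.time() - start_time
            self._callback(dict(
                loss=loss,
                params=params,
                itr=itr[0],
                elapsed=elapsed,
            ))
            itr[0] += 1
    else:
        opt_callback = None

    scipy.optimize.fmin_l_bfgs_b(
        func=f_opt_wrapper,
        x0=self._target.get_param_values(trainable=True),
        maxiter=self._max_opt_itr,
        callback=opt_callback,
    )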
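
The snippet above depends on a surrounding class (`self._opt_fun`, `self._target`, `self._callback`) that is not shown here. As a rough, standalone sketch of the same pattern, the example below minimizes a toy quadratic with `scipy.optimize.fmin_l_bfgs_b` and records each iteration through a callback. The names `run_lbfgs`, `target`, and `log` are hypothetical and exist only for this illustration; they are not part of the original project.

```python
import time

import numpy as np
import scipy.optimize


def run_lbfgs(x0, max_itr=50):
    """Minimize a toy quadratic, logging each iteration through a callback."""
    target = np.array([1.0, -2.0, 3.0])  # hypothetical optimum for the demo
    log = []
    itr = [0]  # one-element list so the nested callback can mutate the counter
    start_time = time.time()

    def f_opt(x):
        # Without fprime or approx_grad, func must return (loss, gradient).
        diff = x - target
        return 0.5 * float(diff @ diff), diff

    def opt_callback(xk):
        # SciPy calls this once per iteration with the current parameters.
        log.append(dict(
            loss=f_opt(xk)[0],
            params=xk.copy(),
            itr=itr[0],
            elapsed=time.time() - start_time,
        ))
        itr[0] += 1

    x_opt, f_min, info = scipy.optimize.fmin_l_bfgs_b(
        func=f_opt, x0=x0, maxiter=max_itr, callback=opt_callback,
    )
    return x_opt, f_min, info, log


x_opt, f_min, info, log = run_lbfgs(np.zeros(3))
```

Here `x_opt` should land close to `target`, and `log` holds one dictionary per iteration with the same keys (`loss`, `params`, `itr`, `elapsed`) that the original callback receives.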