lbfgs_optimizer.py source code

python

Project: third_person_im  Author: bstadie
def update_opt(self, loss, target, inputs, extra_inputs=None, *args, **kwargs):
        """
        :param loss: Symbolic expression for the loss function.
        :param target: A parameterized object to optimize over. It should implement the methods of the
        :class:`rllab.core.parameterized.Parameterized` class.
        :param inputs: A list of symbolic variables used as inputs.
        :param extra_inputs: An optional list of additional symbolic variables, appended to ``inputs``
        when the functions are compiled.
        :return: No return value.
        """

        self._target = target

        def get_opt_output():
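            # Gradient of the loss w.r.t. every trainable parameter, flattened into a single vector.
            flat_grad = tensor_utils.flatten_tensor_variables(tf.gradients(loss, target.get_params(trainable=True)))
            # Both outputs are cast to float64 before being returned.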
            return [tf.cast(loss, tf.float64), tf.cast(flat_grad, tf.float64)]

        if extra_inputs is None:
            extra_inputs = list()

        self._opt_fun = ext.lazydict(
            f_loss=lambda: tensor_utils.compile_function(inputs + extra_inputs, loss),
            f_opt=lambda: tensor_utils.compile_function(
                inputs=inputs + extra_inputs,
                outputs=get_opt_output(),
            )
        )
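
The method above wraps both compiled functions in lambdas, so neither is built until it is first looked up. The following is a minimal sketch of that lazy pattern in plain Python. `LazyDict`, `compile_stub`, and `loss_and_grad` are hypothetical stand-ins written for illustration; they are not the actual `ext.lazydict` or `tensor_utils.compile_function`, and a toy quadratic replaces the symbolic loss.

```python
class LazyDict:
    """Hypothetical stand-in for ext.lazydict: values are built on first access."""

    def __init__(self, **factories):
        self._factories = factories  # name -> zero-argument callable
        self._cache = {}             # name -> value built so far

    def __getitem__(self, key):
        if key not in self._cache:
            self._cache[key] = self._factories[key]()
        return self._cache[key]


compile_log = []  # records which entries were actually built


def compile_stub(name, fn):
    """Pretend to be an expensive compile step; returns fn unchanged."""
    compile_log.append(name)
    return fn


def loss_and_grad(x):
    """Toy loss (x - 3)^2 and its gradient as a flat list of floats."""
    return float((x - 3.0) ** 2), [float(2.0 * (x - 3.0))]


opt_fun = LazyDict(
    f_loss=lambda: compile_stub("f_loss", lambda x: (x - 3.0) ** 2),
    f_opt=lambda: compile_stub("f_opt", loss_and_grad),
)

loss, flat_grad = opt_fun["f_opt"](5.0)  # builds f_opt only
opt_fun["f_opt"](4.0)                    # reuses the cached function
```

Only the entry that is actually requested gets built, which is presumably why the original defers compilation: a caller that needs only `f_opt` never pays for `f_loss`.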