iterative.py source file

python

Project: tensorforce  Author: reinforceio
    def tf_solve(self, fn_x, x_init, *args):
        """
        Iteratively solves an equation/optimization for $x$ involving an expression $f(x)$.

        Args:
            fn_x: A callable returning an expression $f(x)$ given $x$.
            x_init: Initial solution guess $x_0$.
            *args: Additional solver-specific arguments.

        Returns:
            A solution $x$ to the problem as given by the solver.
        """
        self.fn_x = fn_x

        # Initialization step
        args = self.initialize(x_init, *args)

        # Iteration loop with termination condition
        if self.unroll_loop:
            # Unrolled for loop
            for _ in range(self.max_iterations):
                next_step = self.next_step(*args)
                step = (lambda: self.step(*args))
                do_nothing = (lambda: args)
                args = tf.cond(pred=next_step, true_fn=step, false_fn=do_nothing)

        else:
            # TensorFlow while loop
            args = tf.while_loop(cond=self.next_step, body=self.step, loop_vars=args)

        # First argument contains solution
        return args[0]
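
The control flow above (initialize, then repeat `step` while `next_step` holds, either unrolled or in a real loop) can be tried without TensorFlow. The following is only a plain-Python sketch of that pattern, not tensorforce code: the class `FixedPointSolver`, its `tolerance` parameter, and the loop state `(x, iteration, delta)` are assumptions made for illustration. It solves $x = f(x)$ by fixed-point iteration.

```python
import math


class FixedPointSolver:
    """Plain-Python analogue of the iterative solver pattern (hypothetical class)."""

    def __init__(self, max_iterations, tolerance=1e-10, unroll_loop=False):
        self.max_iterations = max_iterations
        self.tolerance = tolerance  # assumed convergence threshold
        self.unroll_loop = unroll_loop

    def initialize(self, x_init):
        # Loop state: (current solution, iteration counter, last change in x)
        return x_init, 0, float("inf")

    def next_step(self, x, iteration, delta):
        # Continue while under the iteration budget and not yet converged
        return iteration < self.max_iterations and delta > self.tolerance

    def step(self, x, iteration, delta):
        # One fixed-point update: x <- f(x)
        next_x = self.fn_x(x)
        return next_x, iteration + 1, abs(next_x - x)

    def solve(self, fn_x, x_init):
        self.fn_x = fn_x
        args = self.initialize(x_init)
        if self.unroll_loop:
            # Fixed number of guarded steps, mirroring the chain of tf.cond calls
            for _ in range(self.max_iterations):
                args = self.step(*args) if self.next_step(*args) else args
        else:
            # Plays the role of tf.while_loop
            while self.next_step(*args):
                args = self.step(*args)
        # First element of the loop state is the solution
        return args[0]


# Solve x = cos(x) with both loop variants
looped = FixedPointSolver(max_iterations=200).solve(math.cos, 1.0)
unrolled = FixedPointSolver(max_iterations=200, unroll_loop=True).solve(math.cos, 1.0)
```

Both variants apply the same sequence of updates, so they return the same value. The unrolled form always runs `max_iterations` guarded checks, which is the trade-off the `unroll_loop` flag in the original method appears to expose.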