static_graph.py source code

python

Project: static-define-by-run  Author: bkvogel
```python
import functools

from chainer import function  # assumed: the module imports Chainer's `function` module


def static_backward(func):
    """Decorator to mark a function for inclusion in the backward schedule.

    The decorator is used in the same way as `static_forward`, except that
    it decorates the functions that should be added to the static
    backward-pass schedule. The wrapped function implements the
    computations of the `backward()` method of a Function instance that
    must be executed during every backward pass.

    As with `static_forward`, the wrapped function should not return
    a result, because any return value is ignored.

    Args:
        func: A backward-pass method of a subclass of Function that should
            be inserted into the static backward schedule when
            `static_graph` is enabled. The function must not return a
            value, because any return values will be ignored.

    Returns:
        The wrapped function.

    """
    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        # Capture the function and its arguments (by reference) in a closure
        # so the call can later be replayed from the schedule.
        def no_arg_func():
            #print('In no_arg_func: Calling: ', func)
            func(*args, **kwargs)
            #print("Arguments were: %s, %s" % (args, kwargs))

        # no_arg_func() requires no arguments to call since the arguments of
        # the decorated function are captured by the closure.
        # Run the backward computation now, on the current iteration.
        no_arg_func()
        inst = args[0]
        assert isinstance(inst, function.Function)
        schedule_function = getattr(inst, 'schedule_func', None)
        # If trace mode is on, add the closure to the backward schedule.
        if schedule_function is not None:
            print('Adding function to the backward static schedule.')
            schedule_function.append_backward_function(no_arg_func)

    return wrapped_func
```
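To make the trace-and-replay mechanism concrete, here is a minimal, self-contained sketch that runs without Chainer. It restates the decorator above without the `function.Function` type check. `BackwardSchedule`, `run_backward`, and `ToyFunction` are hypothetical names introduced for illustration; only the `schedule_func` attribute and the `append_backward_function` method name come from the code above, and the real schedule object in static-define-by-run may behave differently.

```python
import functools
from typing import Callable, List


class BackwardSchedule:
    """Hypothetical stand-in for the object stored in `schedule_func`."""

    def __init__(self) -> None:
        self.backward_functions: List[Callable[[], None]] = []

    def append_backward_function(self, fn: Callable[[], None]) -> None:
        # Record a zero-argument closure produced during tracing.
        self.backward_functions.append(fn)

    def run_backward(self) -> None:
        # Hypothetical replay step: call every recorded closure in trace order.
        for fn in self.backward_functions:
            fn()


def static_backward(func):
    """Simplified restatement of the decorator (no Chainer type check)."""
    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        def no_arg_func():
            func(*args, **kwargs)  # arguments captured by the closure

        no_arg_func()  # always run the computation now
        schedule = getattr(args[0], "schedule_func", None)
        if schedule is not None:  # tracing: remember the call for replay
            schedule.append_backward_function(no_arg_func)

    return wrapped_func


class ToyFunction:
    """Hypothetical Function-like class used only for this illustration."""

    def __init__(self, schedule=None):
        self.schedule_func = schedule
        self.calls = 0

    @static_backward
    def compute_grad(self, grad_out, grad_in):
        # Write into a preallocated buffer instead of returning a value,
        # since the decorator discards return values.
        self.calls += 1
        for i, g in enumerate(grad_out):
            grad_in[i] = 2.0 * g


schedule = BackwardSchedule()
f = ToyFunction(schedule)
grad_out = [1.0, 2.0]
grad_in = [0.0, 0.0]
f.compute_grad(grad_out, grad_in)  # traced call: runs once and is recorded
grad_out[0] = 5.0                  # update the input buffer in place
schedule.run_backward()            # replay with no arguments
print(grad_in, f.calls)            # prints "[10.0, 4.0] 2"
```

Because the closure holds references to the original argument objects rather than copies, the replayed call sees in-place updates to those buffers. This is presumably why the docstring asks decorated methods not to return a value: under this scheme, results have to land in buffers that persist between iterations.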