def static_backward(func):
    """Decorator to mark a function for inclusion in the backward schedule.

    The decorator is used in the same way as `static_forward` except that
    it is used to decorate the functions that should be added to the static
    backward-pass schedule. The wrapped function implements the
    computations of the `backward()` method of a Function instance that
    must be executed during every backward pass.

    Similarly to `static_forward`, the wrapped function should not return
    a result because it will be ignored.

    Each call to the returned wrapper:

    1. Immediately calls ``func(*args, **kwargs)`` and discards its result.
    2. Asserts that ``args[0]`` (the instance the method is bound to) is a
       ``function.Function``.
    3. If that instance has a non-``None`` ``schedule_func`` attribute,
       prints a notice and passes a zero-argument closure that re-invokes
       ``func`` with the same arguments to
       ``schedule_func.append_backward_function``.

    Args:
        func: A backward-pass method of a sub-class of Function that should
            be inserted into the static backward schedule when
            `static_graph` is enabled. The function must not return a value
            because any return values will be ignored.

    Returns:
        The wrapped function. It carries ``func``'s metadata (``__name__``,
        ``__qualname__``, ``__doc__``, ``__wrapped__``) via
        ``functools.wraps`` and always returns ``None``.
    """
    # Local import: the file's top-level import block is not visible here.
    import functools

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        # Capture this call's arguments in a closure so the same computation
        # can later be invoked with no arguments (e.g. from the schedule).
        def no_arg_func():
            func(*args, **kwargs)

        # Execute the backward computation for the current pass right away.
        no_arg_func()

        inst = args[0]
        assert isinstance(inst, function.Function)
        schedule_function = getattr(inst, 'schedule_func', None)
        # A non-None `schedule_func` means trace mode is on; only then is
        # the closure recorded in the static backward schedule.
        if schedule_function is not None:
            print('Adding function to the backward static schedule.')
            schedule_function.append_backward_function(no_arg_func)

    return wrapped_func
评论列表
文章目录