def once_differentiable(fn):
    """Make ``fn`` run on raw tensors and block differentiating its outputs.

    The returned wrapper calls ``fn(ctx, *tensor_args)``, where every
    ``Variable`` in the positional arguments is replaced by its ``.data`` and
    any other argument is passed through unchanged. Whatever ``fn`` returns
    (a single value or a tuple, with ``None`` entries preserved as ``None``)
    is re-wrapped in ``Variable`` objects. Judging by the error message
    below, this is meant for an autograd ``backward`` that does not support
    being differentiated a second time -- NOTE(review): confirm with callers.

    Output handling:
      * If any ``Variable`` argument is volatile, outputs are built with
        ``volatile=True`` and returned through an identity function. That
        path always returns a tuple, so a single output comes back as a
        1-tuple.
      * Otherwise, outputs get ``requires_grad=True`` if any ``Variable``
        argument requires grad. They are then passed through
        ``torch._C._functions.DelayedError``, presumably so that trying to
        differentiate them reports the message below.

    Args:
        fn: callable invoked as ``fn(ctx, *tensor_args)``.

    Returns:
        The wrapper function, with ``fn``'s metadata copied by
        ``functools.wraps``.
    """
    from .variable import Variable

    @functools.wraps(fn)
    def wrapper(ctx, *args):
        tensor_args = [arg.data if isinstance(arg, Variable) else arg
                       for arg in args]
        outputs = fn(ctx, *tensor_args)
        # XXX: this is only an approximation of these flags - there's no way
        # to figure out if fn didn't use ctx.saved_variables and as a result
        # some Variables might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        var_args = [arg for arg in args if isinstance(arg, Variable)]
        volatile = any(arg.volatile for arg in var_args)
        requires_grad = any(arg.requires_grad for arg in var_args)
        if volatile:
            # Identity passthrough; returns its arguments as a tuple.
            def err_fn(*outs):
                return outs
            kwargs = {'volatile': True}
        else:
            # The two literals are concatenated by the parser; the trailing
            # space keeps "marked" and "with" from running together.
            err_fn = torch._C._functions.DelayedError(
                b"trying to differentiate twice a function that was marked "
                b"with @once_differentiable")
            kwargs = {'requires_grad': requires_grad}

        def wrap(out):
            # Re-wrap a raw output with the computed flags; keep None as None.
            return Variable(out, **kwargs) if out is not None else None

        if not isinstance(outputs, tuple):
            return err_fn(wrap(outputs))
        return err_fn(*[wrap(o) for o in outputs])

    return wrapper
评论列表
文章目录