function.py source code

Language: python

Project: pytorch    Author: tylergenter
import functools

import torch


def once_differentiable(fn):
    from .variable import Variable

    @functools.wraps(fn)
    def wrapper(ctx, *args):
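        # Unwrap Variable arguments to their underlying tensors before calling fn.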
        tensor_args = [arg.data if isinstance(arg, Variable) else arg
                       for arg in args]
        outputs = fn(ctx, *tensor_args)
        # XXX: this is only an approximation of these flags - there's no way
        # to figure out if fn didn't use ctx.saved_variables and as a result
        # some Variables might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        volatile = any(arg.volatile if isinstance(arg, Variable) else False
                       for arg in args)
        requires_grad = any(arg.requires_grad if isinstance(arg, Variable) else False
                            for arg in args)
        if volatile:
            def err_fn(*args):
                return args
            kwargs = {'volatile': True}
        else:
            err_fn = torch._C._functions.DelayedError(
                b"trying to differentiate twice a function that was marked "
                b"with @once_differentiable")
            kwargs = {'requires_grad': requires_grad}
        if not isinstance(outputs, tuple):
            var = Variable(outputs, **kwargs) if outputs is not None else None
            return err_fn(var)
        return err_fn(*[Variable(o, **kwargs) if o is not None else None
                        for o in outputs])
    return wrapper
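
For context, `once_differentiable` is typically applied to the `backward` of a custom autograd `Function`: the decorated backward receives plain tensors, and its outputs are routed through `DelayedError` so that attempting to differentiate them a second time fails with the message above instead of silently producing wrong gradients. Below is a minimal usage sketch against the public torch.autograd API; `MyExp` is a hypothetical example, not part of this file.

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class MyExp(Function):
    @staticmethod
    def forward(ctx, x):
        y = x.exp()
        ctx.save_for_backward(y)
        return y

    @staticmethod
    @once_differentiable  # backward sees plain tensors; differentiating it again raises
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        return grad_output * y

x = torch.randn(3, requires_grad=True)
out = MyExp.apply(x)
out.sum().backward()  # first-order gradients work as usual

First-order gradients behave normally; only a second differentiation through this backward (e.g. grad-of-grad) triggers the delayed error.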