test_autograd.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def create_input(call_args, requires_grad=True):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
            return arg
        elif isinstance(arg, tuple) and not isinstance(arg[0], Variable):
            return Variable(torch.randn(*arg).double(), requires_grad=requires_grad)
        elif torch.is_tensor(arg):
            if isinstance(arg, torch.FloatTensor):
                return Variable(arg.double(), requires_grad=requires_grad)
            else:
                return Variable(arg, requires_grad=requires_grad)
        else:
            return arg
    return tuple(map_arg(arg) for arg in call_args)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号