test_autograd.py source code

python

Project: pytorch    Author: ezyang
import torch
from torch.autograd import Variable

# Note: make_non_contiguous (returns a non-contiguous copy of a tensor) and
# dont_convert (a marker type for arguments to pass through unchanged) are
# helpers defined elsewhere in the PyTorch test suite.

def create_input(call_args, requires_grad=True, non_contiguous=False):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        def maybe_non_contig(tensor):
            return tensor if not non_contiguous else make_non_contiguous(tensor)

        if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
            # Sizes and explicitly marked arguments are passed through as-is.
            return arg
        elif isinstance(arg, tuple) and not isinstance(arg[0], Variable):
            # A plain tuple of ints is treated as a shape: create a random
            # double tensor of that shape and wrap it in a Variable.
            return Variable(maybe_non_contig(torch.randn(*arg).double()), requires_grad=requires_grad)
        elif torch.is_tensor(arg):
            if isinstance(arg, torch.FloatTensor):
                # Promote float tensors to double for better numerical precision.
                return Variable(maybe_non_contig(arg.double()), requires_grad=requires_grad)
            else:
                return Variable(maybe_non_contig(arg), requires_grad=requires_grad)
        elif isinstance(arg, Variable) and non_contiguous:
            # Rewrap an existing Variable around a non-contiguous copy of its data.
            return Variable(maybe_non_contig(arg.data), requires_grad=arg.requires_grad)
        else:
            return arg
    return tuple(map_arg(arg) for arg in call_args)
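
A minimal usage sketch (illustrative only; the argument values are assumptions, and the legacy Variable API of this version of the test file is used): a shape tuple becomes a fresh random double Variable, and an existing FloatTensor is promoted to double and wrapped.

# Illustrative example, not part of the original test file.
x, y = create_input(((2, 3), torch.randn(2, 3)))
print(x.size(), x.requires_grad)   # torch.Size([2, 3]) True
print(y.data.type())               # 'torch.DoubleTensor'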