def create_input(call_args, requires_grad=True, non_contiguous=False):
    """Build the tuple of call inputs described by ``call_args``.

    Each entry of ``call_args`` is converted independently:

    * ``torch.Size`` and ``dont_convert`` instances are returned unchanged.
    * An empty tuple is treated as an empty shape and becomes a
      zero-dimensional double ``Variable``.
    * Any other tuple whose first element is not a ``Variable`` is treated
      as a shape and becomes a double ``Variable`` filled by ``torch.randn``.
    * A tensor is wrapped in a ``Variable``; a ``torch.FloatTensor`` is
      converted to double first, other tensor types are wrapped as they are.
    * A ``Variable`` is rebuilt around ``make_non_contiguous(arg.data)``,
      keeping its own ``requires_grad``, but only when ``non_contiguous``
      is set.
    * Anything else is returned unchanged.

    Args:
        call_args: A single argument spec or a tuple of them. A value that
            is not a tuple is treated as a one-element tuple.
        requires_grad: ``requires_grad`` flag given to Variables created
            from shapes and tensors.
        non_contiguous: When true, tensors are passed through
            ``make_non_contiguous`` before being wrapped.

    Returns:
        A tuple holding one converted entry per element of ``call_args``.
    """
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def maybe_non_contig(tensor):
        return make_non_contiguous(tensor) if non_contiguous else tensor

    def as_variable(tensor):
        return Variable(maybe_non_contig(tensor), requires_grad=requires_grad)

    def map_arg(arg):
        # These pass-through checks must come first: they have to win over
        # the generic tuple handling below.
        if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
            return arg

        if isinstance(arg, tuple):
            if not arg:
                # An empty shape describes a scalar. Indexing arg[0] here
                # used to raise IndexError. A single-element tensor has no
                # non-contiguous layout, so make_non_contiguous is skipped.
                return Variable(torch.randn(()).double(),
                                requires_grad=requires_grad)
            if not isinstance(arg[0], Variable):
                return as_variable(torch.randn(*arg).double())
            # A tuple of Variables is already usable as an input.
            return arg

        if torch.is_tensor(arg):
            if isinstance(arg, torch.FloatTensor):
                return as_variable(arg.double())
            return as_variable(arg)

        if isinstance(arg, Variable) and non_contiguous:
            return Variable(maybe_non_contig(arg.data),
                            requires_grad=arg.requires_grad)

        return arg

    return tuple(map_arg(arg) for arg in call_args)
评论列表
文章目录