def create_input(call_args, requires_grad=True):
    """Convert a test-case argument spec into concrete call arguments.

    ``call_args`` may be a single spec or a tuple of specs; a non-tuple value
    is first wrapped into a 1-tuple. Each spec is then mapped as follows:

    * ``torch.Size`` and ``dont_convert`` instances are returned unchanged.
    * A non-empty plain tuple whose first element is not a ``Variable`` is
      treated as a shape: it becomes a ``Variable`` wrapping
      ``torch.randn(*shape)`` cast to double.
    * A tensor is wrapped in a ``Variable``; ``torch.FloatTensor`` inputs are
      cast to double first, other tensor types keep their dtype.
    * Anything else (including an empty tuple) is returned unchanged.

    Args:
        call_args: A single argument spec or a tuple of specs.
        requires_grad: Forwarded to every ``Variable`` this function creates.

    Returns:
        A tuple containing one converted entry per spec in ``call_args``.
    """
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        # Must run before the generic tuple branch: torch.Size is a tuple
        # subclass, and dont_convert presumably is too (TODO confirm), so
        # both would otherwise be mistaken for shape specs.
        if isinstance(arg, (torch.Size, dont_convert)):
            return arg
        # A non-empty tuple of plain values is a shape spec. The emptiness
        # check prevents the IndexError that ``arg[0]`` raised on ``()``;
        # an empty tuple now falls through and is passed along unchanged.
        elif isinstance(arg, tuple) and arg and not isinstance(arg[0], Variable):
            return Variable(torch.randn(*arg).double(), requires_grad=requires_grad)
        elif torch.is_tensor(arg):
            # Only torch.FloatTensor inputs are promoted to double; every
            # other tensor type is wrapped without changing its dtype.
            if isinstance(arg, torch.FloatTensor):
                return Variable(arg.double(), requires_grad=requires_grad)
            return Variable(arg, requires_grad=requires_grad)
        # Scalars, tuples of Variables, empty tuples, etc. pass through.
        return arg

    return tuple(map_arg(arg) for arg in call_args)
评论列表
文章目录