def make_non_contiguous(tensor):
osize = list(tensor.size())
# randomly inflate a few dimensions in osize
for _ in range(2):
dim = random.randint(0, len(osize) - 1)
add = random.randint(4, 15)
osize[dim] = osize[dim] + add
# narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
# (which will always happen with a 1-dimensional tensor), so let's make a new
# right-most dimension and cut it off
input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
input = input.select(len(input.size()) - 1, random.randint(0, 1))
# now extract the input of correct size from 'input'
for i in range(len(osize)):
if input.size(i) != tensor.size(i):
bounds = random.randint(1, input.size(i) - tensor.size(i))
input = input.narrow(i, bounds, tensor.size(i))
input.copy_(tensor)
return input
评论列表
文章目录