def test_MaxUnpool2d_output_size(self):
    """Exercise MaxUnpool2d with implicit and explicit output sizes.

    Unpooling indices that point outside the default output map must raise
    RuntimeError. For the pooled 5x5 input, every explicit output size in
    [4, 6] x [4, 6] must be accepted -- as a plain tuple, a 2-element
    LongStorage, or a 4-element LongStorage -- while every other size in
    [3, 9] x [3, 9] must raise ValueError.
    """
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    # Put the maximum of the last 3x3 window at (4, 4) of a 6x6 map. That
    # cell lies beyond the default unpooled map (presumably 5x5, from
    # (in - 1) * stride + kernel), so unpooling without an explicit output
    # size is expected to fail.
    big_input = torch.rand(1, 1, 6, 6)
    big_input[0][0][4][4] = 100
    big_out, big_idx = pool(Variable(big_input))
    self.assertRaises(RuntimeError, lambda: unpool(big_out, big_idx))

    # Plant maxima at (0, 0), (0, 2), (2, 0) and (2, 2) of a 5x5 map.
    small_input = torch.rand(1, 1, 5, 5)
    corners = [(r, c) for r in range(0, 4, 2) for c in range(0, 4, 2)]
    for row, col in corners:
        small_input[:, :, row, col] = 100
    small_out, small_idx = pool(Variable(small_input))

    candidates = [(hh, ww) for hh in range(3, 10) for ww in range(3, 10)]
    for h, w in candidates:
        accepted = 4 <= h <= 6 and 4 <= w <= 6
        if not accepted:
            # assertRaises invokes the lambda immediately, so it sees this h, w.
            self.assertRaises(ValueError,
                              lambda: unpool(small_out, small_idx, (h, w)))
            continue
        # Cycle through the output_size spellings that must be accepted.
        if h == 5:
            requested = torch.LongStorage((h, w))
        elif h == 6:
            requested = torch.LongStorage((1, 1, h, w))
        else:
            requested = (h, w)
        unpool(small_out, small_idx, output_size=requested)
python类LongStorage()的实例源码
def make_tensor_reader(typename):
    """Return a reader callback that deserializes a torch7 tensor of ``typename``.

    The returned ``read_tensor(reader, version)`` consumes, in order: the
    number of dimensions, the size array, the stride array, the storage
    offset and the backing storage. It then builds an instance of the Python
    class that ``get_python_class`` returns for ``typename``. ``version`` is
    accepted but not used.
    """
    tensor_cls = get_python_class(typename)

    def read_tensor(reader, version):
        # Field order mirrors torch7's serializer:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        # Every field is read before any early return, presumably so the
        # stream is left positioned just after this tensor.
        n_dims = reader.read_int()
        dims = torch.LongStorage(reader.read_long_array(n_dims))
        strides = torch.LongStorage(reader.read_long_array(n_dims))
        # The stored offset is 1-based (torch7 is Lua-based), hence the - 1.
        offset = reader.read_long() - 1
        backing = reader.read()

        has_no_data = (backing is None or n_dims == 0
                       or len(dims) == 0 or len(strides) == 0)
        if has_no_data:
            # Nothing to view into: hand back an empty tensor.
            return tensor_cls()
        return tensor_cls().set_(backing, offset, torch.Size(dims),
                                 tuple(strides))

    return read_tensor
def test_MaxUnpool2d_output_size(self):
    """Exercise MaxUnpool2d with implicit and explicit output sizes.

    Unpooling indices that point outside the default output map must raise
    RuntimeError. For the pooled 5x5 input, every explicit output size in
    [4, 6] x [4, 6] must be accepted -- as a plain tuple, a 2-element
    LongStorage, or a 4-element LongStorage -- while every other size in
    [3, 9] x [3, 9] must raise ValueError.
    """
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    # Put the maximum of the last 3x3 window at (4, 4) of a 6x6 map. That
    # cell lies beyond the default unpooled map (presumably 5x5, from
    # (in - 1) * stride + kernel), so unpooling without an explicit output
    # size is expected to fail.
    big_input = torch.rand(1, 1, 6, 6)
    big_input[0][0][4][4] = 100
    big_out, big_idx = pool(Variable(big_input))
    self.assertRaises(RuntimeError, lambda: unpool(big_out, big_idx))

    # Plant maxima at (0, 0), (0, 2), (2, 0) and (2, 2) of a 5x5 map.
    small_input = torch.rand(1, 1, 5, 5)
    corners = [(r, c) for r in range(0, 4, 2) for c in range(0, 4, 2)]
    for row, col in corners:
        small_input[:, :, row, col] = 100
    small_out, small_idx = pool(Variable(small_input))

    candidates = [(hh, ww) for hh in range(3, 10) for ww in range(3, 10)]
    for h, w in candidates:
        accepted = 4 <= h <= 6 and 4 <= w <= 6
        if not accepted:
            # assertRaises invokes the lambda immediately, so it sees this h, w.
            self.assertRaises(ValueError,
                              lambda: unpool(small_out, small_idx, (h, w)))
            continue
        # Cycle through the output_size spellings that must be accepted.
        if h == 5:
            requested = torch.LongStorage((h, w))
        elif h == 6:
            requested = torch.LongStorage((1, 1, h, w))
        else:
            requested = (h, w)
        unpool(small_out, small_idx, output_size=requested)
def make_tensor_reader(typename):
    """Return a reader callback that deserializes a torch7 tensor of ``typename``.

    The returned ``read_tensor(reader, version)`` consumes, in order: the
    number of dimensions, the size array, the stride array, the storage
    offset and the backing storage. It then builds an instance of the Python
    class that ``get_python_class`` returns for ``typename``. ``version`` is
    accepted but not used.
    """
    tensor_cls = get_python_class(typename)

    def read_tensor(reader, version):
        # Field order mirrors torch7's serializer:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        # Every field is read before any early return, presumably so the
        # stream is left positioned just after this tensor.
        n_dims = reader.read_int()
        dims = torch.LongStorage(reader.read_long_array(n_dims))
        strides = torch.LongStorage(reader.read_long_array(n_dims))
        # The stored offset is 1-based (torch7 is Lua-based), hence the - 1.
        offset = reader.read_long() - 1
        backing = reader.read()

        has_no_data = (backing is None or n_dims == 0
                       or len(dims) == 0 or len(strides) == 0)
        if has_no_data:
            # Nothing to view into: hand back an empty tensor.
            return tensor_cls()
        return tensor_cls().set_(backing, offset, torch.Size(dims),
                                 tuple(strides))

    return read_tensor
def test_MaxUnpool2d_output_size(self):
    """Exercise MaxUnpool2d with implicit and explicit output sizes.

    Unpooling indices that point outside the default output map must raise
    RuntimeError. For the pooled 5x5 input, every explicit output size in
    [4, 6] x [4, 6] must be accepted -- as a plain tuple, a 2-element
    LongStorage, or a 4-element LongStorage -- while every other size in
    [3, 9] x [3, 9] must raise ValueError.
    """
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    # Put the maximum of the last 3x3 window at (4, 4) of a 6x6 map. That
    # cell lies beyond the default unpooled map (presumably 5x5, from
    # (in - 1) * stride + kernel), so unpooling without an explicit output
    # size is expected to fail.
    big_input = torch.rand(1, 1, 6, 6)
    big_input[0][0][4][4] = 100
    big_out, big_idx = pool(Variable(big_input))
    self.assertRaises(RuntimeError, lambda: unpool(big_out, big_idx))

    # Plant maxima at (0, 0), (0, 2), (2, 0) and (2, 2) of a 5x5 map.
    small_input = torch.rand(1, 1, 5, 5)
    corners = [(r, c) for r in range(0, 4, 2) for c in range(0, 4, 2)]
    for row, col in corners:
        small_input[:, :, row, col] = 100
    small_out, small_idx = pool(Variable(small_input))

    candidates = [(hh, ww) for hh in range(3, 10) for ww in range(3, 10)]
    for h, w in candidates:
        accepted = 4 <= h <= 6 and 4 <= w <= 6
        if not accepted:
            # assertRaises invokes the lambda immediately, so it sees this h, w.
            self.assertRaises(ValueError,
                              lambda: unpool(small_out, small_idx, (h, w)))
            continue
        # Cycle through the output_size spellings that must be accepted.
        if h == 5:
            requested = torch.LongStorage((h, w))
        elif h == 6:
            requested = torch.LongStorage((1, 1, h, w))
        else:
            requested = (h, w)
        unpool(small_out, small_idx, output_size=requested)
def make_tensor_reader(typename):
    """Return a reader callback that deserializes a torch7 tensor of ``typename``.

    The returned ``read_tensor(reader, version)`` consumes, in order: the
    number of dimensions, the size array, the stride array, the storage
    offset and the backing storage. It then builds an instance of the Python
    class that ``get_python_class`` returns for ``typename``. ``version`` is
    accepted but not used.
    """
    tensor_cls = get_python_class(typename)

    def read_tensor(reader, version):
        # Field order mirrors torch7's serializer:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        # Every field is read before any early return, presumably so the
        # stream is left positioned just after this tensor.
        n_dims = reader.read_int()
        dims = torch.LongStorage(reader.read_long_array(n_dims))
        strides = torch.LongStorage(reader.read_long_array(n_dims))
        # The stored offset is 1-based (torch7 is Lua-based), hence the - 1.
        offset = reader.read_long() - 1
        backing = reader.read()

        has_no_data = (backing is None or n_dims == 0
                       or len(dims) == 0 or len(strides) == 0)
        if has_no_data:
            # Nothing to view into: hand back an empty tensor.
            return tensor_cls()
        return tensor_cls().set_(backing, offset, torch.Size(dims),
                                 tuple(strides))

    return read_tensor
def test_MaxUnpool2d_output_size(self):
    """Exercise MaxUnpool2d with implicit and explicit output sizes.

    Unpooling indices that point outside the default output map must raise
    RuntimeError. For the pooled 5x5 input, every explicit output size in
    [4, 6] x [4, 6] must be accepted -- as a plain tuple, a 2-element
    LongStorage, or a 4-element LongStorage -- while every other size in
    [3, 9] x [3, 9] must raise ValueError.
    """
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    # Put the maximum of the last 3x3 window at (4, 4) of a 6x6 map. That
    # cell lies beyond the default unpooled map (presumably 5x5, from
    # (in - 1) * stride + kernel), so unpooling without an explicit output
    # size is expected to fail.
    big_input = torch.rand(1, 1, 6, 6)
    big_input[0][0][4][4] = 100
    big_out, big_idx = pool(Variable(big_input))
    self.assertRaises(RuntimeError, lambda: unpool(big_out, big_idx))

    # Plant maxima at (0, 0), (0, 2), (2, 0) and (2, 2) of a 5x5 map.
    small_input = torch.rand(1, 1, 5, 5)
    corners = [(r, c) for r in range(0, 4, 2) for c in range(0, 4, 2)]
    for row, col in corners:
        small_input[:, :, row, col] = 100
    small_out, small_idx = pool(Variable(small_input))

    candidates = [(hh, ww) for hh in range(3, 10) for ww in range(3, 10)]
    for h, w in candidates:
        accepted = 4 <= h <= 6 and 4 <= w <= 6
        if not accepted:
            # assertRaises invokes the lambda immediately, so it sees this h, w.
            self.assertRaises(ValueError,
                              lambda: unpool(small_out, small_idx, (h, w)))
            continue
        # Cycle through the output_size spellings that must be accepted.
        if h == 5:
            requested = torch.LongStorage((h, w))
        elif h == 6:
            requested = torch.LongStorage((1, 1, h, w))
        else:
            requested = (h, w)
        unpool(small_out, small_idx, output_size=requested)
def make_tensor_reader(typename):
    """Return a reader callback that deserializes a torch7 tensor of ``typename``.

    The returned ``read_tensor(reader, version)`` consumes, in order: the
    number of dimensions, the size array, the stride array, the storage
    offset and the backing storage. It then builds an instance of the Python
    class that ``get_python_class`` returns for ``typename``. ``version`` is
    accepted but not used.
    """
    tensor_cls = get_python_class(typename)

    def read_tensor(reader, version):
        # Field order mirrors torch7's serializer:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        # Every field is read before any early return, presumably so the
        # stream is left positioned just after this tensor.
        n_dims = reader.read_int()
        dims = torch.LongStorage(reader.read_long_array(n_dims))
        strides = torch.LongStorage(reader.read_long_array(n_dims))
        # The stored offset is 1-based (torch7 is Lua-based), hence the - 1.
        offset = reader.read_long() - 1
        backing = reader.read()

        has_no_data = (backing is None or n_dims == 0
                       or len(dims) == 0 or len(strides) == 0)
        if has_no_data:
            # Nothing to view into: hand back an empty tensor.
            return tensor_cls()
        return tensor_cls().set_(backing, offset, torch.Size(dims),
                                 tuple(strides))

    return read_tensor
def test_MaxUnpool2d_output_size(self):
    """Exercise MaxUnpool2d with implicit and explicit output sizes.

    Unpooling indices that point outside the default output map must raise
    RuntimeError. For the pooled 5x5 input, every explicit output size in
    [4, 6] x [4, 6] must be accepted -- as a plain tuple, a 2-element
    LongStorage, or a 4-element LongStorage -- while every other size in
    [3, 9] x [3, 9] must raise ValueError.
    """
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    # Put the maximum of the last 3x3 window at (4, 4) of a 6x6 map. That
    # cell lies beyond the default unpooled map (presumably 5x5, from
    # (in - 1) * stride + kernel), so unpooling without an explicit output
    # size is expected to fail.
    big_input = torch.rand(1, 1, 6, 6)
    big_input[0][0][4][4] = 100
    big_out, big_idx = pool(Variable(big_input))
    self.assertRaises(RuntimeError, lambda: unpool(big_out, big_idx))

    # Plant maxima at (0, 0), (0, 2), (2, 0) and (2, 2) of a 5x5 map.
    small_input = torch.rand(1, 1, 5, 5)
    corners = [(r, c) for r in range(0, 4, 2) for c in range(0, 4, 2)]
    for row, col in corners:
        small_input[:, :, row, col] = 100
    small_out, small_idx = pool(Variable(small_input))

    candidates = [(hh, ww) for hh in range(3, 10) for ww in range(3, 10)]
    for h, w in candidates:
        accepted = 4 <= h <= 6 and 4 <= w <= 6
        if not accepted:
            # assertRaises invokes the lambda immediately, so it sees this h, w.
            self.assertRaises(ValueError,
                              lambda: unpool(small_out, small_idx, (h, w)))
            continue
        # Cycle through the output_size spellings that must be accepted.
        if h == 5:
            requested = torch.LongStorage((h, w))
        elif h == 6:
            requested = torch.LongStorage((1, 1, h, w))
        else:
            requested = (h, w)
        unpool(small_out, small_idx, output_size=requested)
def test_repeat(self):
    """Check Tensor.repeat with varargs and with a torch.Size argument.

    Repeating an ``initial_shape`` tensor by (3, 1, 1) must give size
    [3, 8, 4] for both call forms. Each of the 3 tiles must equal the
    source tensor.
    """
    initial_shape = (8, 4)
    tensor = torch.rand(*initial_shape)
    size = (3, 1, 1)
    torchSize = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torchSize).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torchSize)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    # Check every tile exactly rather than their mean. A mean can hide tiles
    # that differ from one another, and (x + x + x) / 3 is not always
    # bit-identical to x in floating point.
    for tile in range(size[0]):
        self.assertEqual(result[tile], tensor, 'Error in repeat (not equal)')
def test_ConvTranspose2d_output_size(self):
    """Check which explicit output sizes ConvTranspose2d accepts, and that they are honoured.

    The module is built with 3 -> 4 channels, kernel 3, stride 3, padding 0
    and output_padding 2, and the input is (2, 3, 6, 6). Every spatial size
    in [18, 20] x [18, 20] must be accepted and must be the size actually
    produced. It is passed as a plain tuple, a 2-element LongStorage, or a
    4-element (batch, channels, h, w) LongStorage. Every other size in
    [15, 21] x [15, 21] must raise ValueError.
    """
    m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
    i = Variable(torch.randn(2, 3, 6, 6))
    for h in range(15, 22):
        for w in range(15, 22):
            if 18 <= h <= 20 and 18 <= w <= 20:
                size = (h, w)
                if h == 19:
                    size = torch.LongStorage(size)
                elif h == 20:
                    # Full-shape spelling: batch 2 and 4 output channels.
                    size = torch.LongStorage((2, 4) + size)
                output = m(i, output_size=size)
                self.assertEqual(tuple(output.size()[2:]), (h, w))
            else:
                # assertRaises invokes the lambda immediately, so it sees this h, w.
                self.assertRaises(ValueError, lambda: m(i, (h, w)))
def test_repeat(self):
    """Check Tensor.repeat with varargs and with a torch.Size argument.

    Repeating an (8, 4) tensor by (3, 1, 1) must give size [3, 8, 4] for
    both call forms. Each of the 3 tiles must equal the source tensor.
    """
    tensor = torch.rand(8, 4)
    size = (3, 1, 1)
    torchSize = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torchSize).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torchSize)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    # Check every tile exactly rather than their mean. A mean can hide tiles
    # that differ from one another, and (x + x + x) / 3 is not always
    # bit-identical to x in floating point.
    for tile in range(size[0]):
        self.assertEqual((result[tile] - tensor).abs().max(), 0,
                         'Error in repeat (not equal)')
def test_element_size(self):
    """Check that Storage and Tensor element sizes agree and are plausible.

    For each scalar type, the Storage class and the Tensor class must report
    the same positive ``element_size()``. The final checks encode portable
    C-type size relations rather than exact sizes for any one platform.
    """
    type_pairs = [
        ('byte', torch.ByteStorage, torch.ByteTensor),
        ('char', torch.CharStorage, torch.CharTensor),
        ('short', torch.ShortStorage, torch.ShortTensor),
        ('int', torch.IntStorage, torch.IntTensor),
        ('long', torch.LongStorage, torch.LongTensor),
        ('float', torch.FloatStorage, torch.FloatTensor),
        ('double', torch.DoubleStorage, torch.DoubleTensor),
    ]
    # Keyed by type name so no local shadows the int/float builtins.
    sizes = {}
    for name, storage_cls, tensor_cls in type_pairs:
        sizes[name] = storage_cls().element_size()
        self.assertEqual(sizes[name], tensor_cls().element_size())
        self.assertGreater(sizes[name], 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(sizes['byte'], 1)
    self.assertEqual(sizes['char'], 1)
    self.assertGreaterEqual(sizes['short'], 2)
    self.assertGreaterEqual(sizes['int'], 2)
    self.assertGreaterEqual(sizes['int'], sizes['short'])
    self.assertGreaterEqual(sizes['long'], 4)
    self.assertGreaterEqual(sizes['long'], sizes['int'])
    self.assertGreaterEqual(sizes['double'], sizes['float'])
def test_repeat(self):
    """Check Tensor.repeat with varargs and with a torch.Size argument.

    Repeating an (8, 4) tensor by (3, 1, 1) must give size [3, 8, 4] for
    both call forms. Each of the 3 tiles must equal the source tensor.
    """
    tensor = torch.rand(8, 4)
    size = (3, 1, 1)
    torchSize = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torchSize).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torchSize)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    # Check every tile exactly rather than their mean. A mean can hide tiles
    # that differ from one another, and (x + x + x) / 3 is not always
    # bit-identical to x in floating point.
    for tile in range(size[0]):
        self.assertEqual((result[tile] - tensor).abs().max(), 0,
                         'Error in repeat (not equal)')
def test_element_size(self):
    """Check that Storage and Tensor element sizes agree and are plausible.

    For each scalar type, the Storage class and the Tensor class must report
    the same positive ``element_size()``. The final checks encode portable
    C-type size relations rather than exact sizes for any one platform.
    """
    type_pairs = [
        ('byte', torch.ByteStorage, torch.ByteTensor),
        ('char', torch.CharStorage, torch.CharTensor),
        ('short', torch.ShortStorage, torch.ShortTensor),
        ('int', torch.IntStorage, torch.IntTensor),
        ('long', torch.LongStorage, torch.LongTensor),
        ('float', torch.FloatStorage, torch.FloatTensor),
        ('double', torch.DoubleStorage, torch.DoubleTensor),
    ]
    # Keyed by type name so no local shadows the int/float builtins.
    sizes = {}
    for name, storage_cls, tensor_cls in type_pairs:
        sizes[name] = storage_cls().element_size()
        self.assertEqual(sizes[name], tensor_cls().element_size())
        self.assertGreater(sizes[name], 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(sizes['byte'], 1)
    self.assertEqual(sizes['char'], 1)
    self.assertGreaterEqual(sizes['short'], 2)
    self.assertGreaterEqual(sizes['int'], 2)
    self.assertGreaterEqual(sizes['int'], sizes['short'])
    self.assertGreaterEqual(sizes['long'], 4)
    self.assertGreaterEqual(sizes['long'], sizes['int'])
    self.assertGreaterEqual(sizes['double'], sizes['float'])
def test_repeat(self):
    """Check Tensor.repeat with varargs and with a torch.Size argument.

    Repeating an (8, 4) tensor by (3, 1, 1) must give size [3, 8, 4] for
    both call forms. Each of the 3 tiles must equal the source tensor.
    """
    tensor = torch.rand(8, 4)
    size = (3, 1, 1)
    torchSize = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torchSize).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torchSize)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    # Check every tile exactly rather than their mean. A mean can hide tiles
    # that differ from one another, and (x + x + x) / 3 is not always
    # bit-identical to x in floating point.
    for tile in range(size[0]):
        self.assertEqual((result[tile] - tensor).abs().max(), 0,
                         'Error in repeat (not equal)')
def test_element_size(self):
    """Check that Storage and Tensor element sizes agree and are plausible.

    For each scalar type, the Storage class and the Tensor class must report
    the same positive ``element_size()``. The final checks encode portable
    C-type size relations rather than exact sizes for any one platform.
    """
    type_pairs = [
        ('byte', torch.ByteStorage, torch.ByteTensor),
        ('char', torch.CharStorage, torch.CharTensor),
        ('short', torch.ShortStorage, torch.ShortTensor),
        ('int', torch.IntStorage, torch.IntTensor),
        ('long', torch.LongStorage, torch.LongTensor),
        ('float', torch.FloatStorage, torch.FloatTensor),
        ('double', torch.DoubleStorage, torch.DoubleTensor),
    ]
    # Keyed by type name so no local shadows the int/float builtins.
    sizes = {}
    for name, storage_cls, tensor_cls in type_pairs:
        sizes[name] = storage_cls().element_size()
        self.assertEqual(sizes[name], tensor_cls().element_size())
        self.assertGreater(sizes[name], 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(sizes['byte'], 1)
    self.assertEqual(sizes['char'], 1)
    self.assertGreaterEqual(sizes['short'], 2)
    self.assertGreaterEqual(sizes['int'], 2)
    self.assertGreaterEqual(sizes['int'], sizes['short'])
    self.assertGreaterEqual(sizes['long'], 4)
    self.assertGreaterEqual(sizes['long'], sizes['int'])
    self.assertGreaterEqual(sizes['double'], sizes['float'])
def test_repeat(self):
    """Check Tensor.repeat with varargs and with a torch.Size argument.

    Repeating an (8, 4) tensor by (3, 1, 1) must give size [3, 8, 4] for
    both call forms. Each of the 3 tiles must equal the source tensor.
    """
    tensor = torch.rand(8, 4)
    size = (3, 1, 1)
    torchSize = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torchSize).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torchSize)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    # Check every tile exactly rather than their mean. A mean can hide tiles
    # that differ from one another, and (x + x + x) / 3 is not always
    # bit-identical to x in floating point.
    for tile in range(size[0]):
        self.assertEqual((result[tile] - tensor).abs().max(), 0,
                         'Error in repeat (not equal)')
def test_element_size(self):
    """Check that Storage and Tensor element sizes agree and are plausible.

    For each scalar type, the Storage class and the Tensor class must report
    the same positive ``element_size()``. The final checks encode portable
    C-type size relations rather than exact sizes for any one platform.
    """
    type_pairs = [
        ('byte', torch.ByteStorage, torch.ByteTensor),
        ('char', torch.CharStorage, torch.CharTensor),
        ('short', torch.ShortStorage, torch.ShortTensor),
        ('int', torch.IntStorage, torch.IntTensor),
        ('long', torch.LongStorage, torch.LongTensor),
        ('float', torch.FloatStorage, torch.FloatTensor),
        ('double', torch.DoubleStorage, torch.DoubleTensor),
    ]
    # Keyed by type name so no local shadows the int/float builtins.
    sizes = {}
    for name, storage_cls, tensor_cls in type_pairs:
        sizes[name] = storage_cls().element_size()
        self.assertEqual(sizes[name], tensor_cls().element_size())
        self.assertGreater(sizes[name], 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(sizes['byte'], 1)
    self.assertEqual(sizes['char'], 1)
    self.assertGreaterEqual(sizes['short'], 2)
    self.assertGreaterEqual(sizes['int'], 2)
    self.assertGreaterEqual(sizes['int'], sizes['short'])
    self.assertGreaterEqual(sizes['long'], 4)
    self.assertGreaterEqual(sizes['long'], sizes['int'])
    self.assertGreaterEqual(sizes['double'], sizes['float'])
def test_element_size(self):
    """Check that Storage and Tensor element sizes agree and are plausible.

    For each scalar type, the Storage class and the Tensor class must report
    the same positive ``element_size()``. The final checks encode portable
    C-type size relations rather than exact sizes for any one platform.
    """
    type_pairs = [
        ('byte', torch.ByteStorage, torch.ByteTensor),
        ('char', torch.CharStorage, torch.CharTensor),
        ('short', torch.ShortStorage, torch.ShortTensor),
        ('int', torch.IntStorage, torch.IntTensor),
        ('long', torch.LongStorage, torch.LongTensor),
        ('float', torch.FloatStorage, torch.FloatTensor),
        ('double', torch.DoubleStorage, torch.DoubleTensor),
    ]
    # Keyed by type name so no local shadows the int/float builtins.
    sizes = {}
    for name, storage_cls, tensor_cls in type_pairs:
        sizes[name] = storage_cls().element_size()
        self.assertEqual(sizes[name], tensor_cls().element_size())
        self.assertGreater(sizes[name], 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(sizes['byte'], 1)
    self.assertEqual(sizes['char'], 1)
    self.assertGreaterEqual(sizes['short'], 2)
    self.assertGreaterEqual(sizes['int'], 2)
    self.assertGreaterEqual(sizes['int'], sizes['short'])
    self.assertGreaterEqual(sizes['long'], 4)
    self.assertGreaterEqual(sizes['long'], sizes['int'])
    self.assertGreaterEqual(sizes['double'], sizes['float'])
def _tensor_str(self):
    """Render a tensor as a sequence of 2-D slices, one per leading index.

    The leading ``self.ndimension() - 2`` dimensions are walked like an
    odometer, with the last leading dimension changing fastest. For each
    index combination a ``(i,j,.,.) =`` header is written, followed by the
    2-D submatrix formatted by ``_matrix_str``. When the tensor has at least
    ``PRINT_OPTS.threshold`` elements, summarizing mode is on. In that mode,
    leading dimensions longer than ``2 * PRINT_OPTS.edgeitems`` visit only
    their first and last ``edgeitems`` indices, and rows of dots are emitted
    between slices.

    NOTE(review): ``counter[counter.size()-1]`` needs at least one leading
    dimension, so this is presumably only called for tensors with 3+ dims --
    confirm at the call site.

    Returns:
        The accumulated string.
    """
    # Number of indices kept at each end of a summarized dimension.
    n = PRINT_OPTS.edgeitems
    # Whether the last two (matrix) dimensions exceed 2*n; only used here to
    # choose the number formatter's minimum width.
    has_hdots = self.size()[-1] > 2*n
    has_vdots = self.size()[-2] > 2*n
    print_full_mat = not has_hdots and not has_vdots
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    # Summarizing mode (skip middle indices, emit dot rows) for large tensors.
    print_dots = self.numel() >= PRINT_OPTS.threshold
    # Width of each index in the slice header: fits the longest dimension
    # size, and is at least 2.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    # One column wider than dim_fmt, presumably to line up with the
    # comma-separated indices in the header.
    dot_fmt = u"{:^" + str(dim_sz+1) + "}"
    counter_dim = self.ndimension() - 2
    # Odometer over the leading dimensions. Starting the last digit at -1
    # makes the first increment land on all zeros.
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size()-1] = -1
    finished = False
    strt = ''
    while True:
        # Per leading dimension: did it wrap around, or jump over its middle,
        # during this advance?
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Advance the odometer, carrying from the last leading dim leftwards.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # In summarizing mode, jump from index n straight to size - n.
            if print_dots and counter[i] == n and self.size(i) > 2*n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Wrapped: reset this digit and carry into the next one.
                # Wrapping the first digit means every slice was emitted.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            # A row with '...' under each dimension that just skipped.
            if any(nskipped):
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            # A row with a vertical ellipsis (U+22EE) under each dimension
            # that just wrapped around.
            if any(nrestarted):
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        # Separate this slice from whatever precedes it (nothing for the
        # very first slice).
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Index away the leading dims one at a time; each select(0, i) drops
        # the current first dimension, leaving the 2-D submatrix.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt
def _tensor_str(self):
    """Render a tensor as a sequence of 2-D slices, one per leading index.

    The leading ``self.ndimension() - 2`` dimensions are walked like an
    odometer, with the last leading dimension changing fastest. For each
    index combination a ``(i,j,.,.) =`` header is written, followed by the
    2-D submatrix formatted by ``_matrix_str``. When the tensor has at least
    ``PRINT_OPTS.threshold`` elements, summarizing mode is on. In that mode,
    leading dimensions longer than ``2 * PRINT_OPTS.edgeitems`` visit only
    their first and last ``edgeitems`` indices, and rows of dots are emitted
    between slices.

    NOTE(review): ``counter[counter.size() - 1]`` needs at least one leading
    dimension, so this is presumably only called for tensors with 3+ dims --
    confirm at the call site.

    Returns:
        The accumulated string.
    """
    # Number of indices kept at each end of a summarized dimension.
    n = PRINT_OPTS.edgeitems
    # Whether the last two (matrix) dimensions exceed 2 * n; only used here
    # to choose the number formatter's minimum width.
    has_hdots = self.size()[-1] > 2 * n
    has_vdots = self.size()[-2] > 2 * n
    print_full_mat = not has_hdots and not has_vdots
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    # Summarizing mode (skip middle indices, emit dot rows) for large tensors.
    print_dots = self.numel() >= PRINT_OPTS.threshold
    # Width of each index in the slice header: fits the longest dimension
    # size, and is at least 2.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    # One column wider than dim_fmt, presumably to line up with the
    # comma-separated indices in the header.
    dot_fmt = u"{:^" + str(dim_sz + 1) + "}"
    counter_dim = self.ndimension() - 2
    # Odometer over the leading dimensions. Starting the last digit at -1
    # makes the first increment land on all zeros.
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size() - 1] = -1
    finished = False
    strt = ''
    while True:
        # Per leading dimension: did it wrap around, or jump over its middle,
        # during this advance?
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Advance the odometer, carrying from the last leading dim leftwards.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # In summarizing mode, jump from index n straight to size - n.
            if print_dots and counter[i] == n and self.size(i) > 2 * n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Wrapped: reset this digit and carry into the next one.
                # Wrapping the first digit means every slice was emitted.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            # A row with '...' under each dimension that just skipped.
            if any(nskipped):
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            # A row with a vertical ellipsis (U+22EE) under each dimension
            # that just wrapped around.
            if any(nrestarted):
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        # Separate this slice from whatever precedes it (nothing for the
        # very first slice).
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Index away the leading dims one at a time; each select(0, i) drops
        # the current first dimension, leaving the 2-D submatrix.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt
def _tensor_str(self):
    """Render a tensor as a sequence of 2-D slices, one per leading index.

    The leading ``self.ndimension() - 2`` dimensions are walked like an
    odometer, with the last leading dimension changing fastest. For each
    index combination a ``(i,j,.,.) =`` header is written, followed by the
    2-D submatrix formatted by ``_matrix_str``. When the tensor has at least
    ``PRINT_OPTS.threshold`` elements, summarizing mode is on. In that mode,
    leading dimensions longer than ``2 * PRINT_OPTS.edgeitems`` visit only
    their first and last ``edgeitems`` indices, and rows of dots are emitted
    between slices.

    NOTE(review): ``counter[counter.size() - 1]`` needs at least one leading
    dimension, so this is presumably only called for tensors with 3+ dims --
    confirm at the call site.

    Returns:
        The accumulated string.
    """
    # Number of indices kept at each end of a summarized dimension.
    n = PRINT_OPTS.edgeitems
    # Whether the last two (matrix) dimensions exceed 2 * n; only used here
    # to choose the number formatter's minimum width.
    has_hdots = self.size()[-1] > 2 * n
    has_vdots = self.size()[-2] > 2 * n
    print_full_mat = not has_hdots and not has_vdots
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    # Summarizing mode (skip middle indices, emit dot rows) for large tensors.
    print_dots = self.numel() >= PRINT_OPTS.threshold
    # Width of each index in the slice header: fits the longest dimension
    # size, and is at least 2.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    # One column wider than dim_fmt, presumably to line up with the
    # comma-separated indices in the header.
    dot_fmt = u"{:^" + str(dim_sz + 1) + "}"
    counter_dim = self.ndimension() - 2
    # Odometer over the leading dimensions. Starting the last digit at -1
    # makes the first increment land on all zeros.
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size() - 1] = -1
    finished = False
    strt = ''
    while True:
        # Per leading dimension: did it wrap around, or jump over its middle,
        # during this advance?
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Advance the odometer, carrying from the last leading dim leftwards.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # In summarizing mode, jump from index n straight to size - n.
            if print_dots and counter[i] == n and self.size(i) > 2 * n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Wrapped: reset this digit and carry into the next one.
                # Wrapping the first digit means every slice was emitted.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            # A row with '...' under each dimension that just skipped.
            if any(nskipped):
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            # A row with a vertical ellipsis (U+22EE) under each dimension
            # that just wrapped around.
            if any(nrestarted):
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        # Separate this slice from whatever precedes it (nothing for the
        # very first slice).
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Index away the leading dims one at a time; each select(0, i) drops
        # the current first dimension, leaving the 2-D submatrix.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt
def _tensor_str(self):
    """Render a tensor as a sequence of 2-D slices, one per leading index.

    The leading ``self.ndimension() - 2`` dimensions are walked like an
    odometer, with the last leading dimension changing fastest. For each
    index combination a ``(i,j,.,.) =`` header is written, followed by the
    2-D submatrix formatted by ``_matrix_str``. When the tensor has at least
    ``PRINT_OPTS.threshold`` elements, summarizing mode is on. In that mode,
    leading dimensions longer than ``2 * PRINT_OPTS.edgeitems`` visit only
    their first and last ``edgeitems`` indices, and rows of dots are emitted
    between slices.

    NOTE(review): ``counter[counter.size() - 1]`` needs at least one leading
    dimension, so this is presumably only called for tensors with 3+ dims --
    confirm at the call site.

    Returns:
        The accumulated string.
    """
    # Number of indices kept at each end of a summarized dimension.
    n = PRINT_OPTS.edgeitems
    # Whether the last two (matrix) dimensions exceed 2 * n; only used here
    # to choose the number formatter's minimum width.
    has_hdots = self.size()[-1] > 2 * n
    has_vdots = self.size()[-2] > 2 * n
    print_full_mat = not has_hdots and not has_vdots
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    # Summarizing mode (skip middle indices, emit dot rows) for large tensors.
    print_dots = self.numel() >= PRINT_OPTS.threshold
    # Width of each index in the slice header: fits the longest dimension
    # size, and is at least 2.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    # One column wider than dim_fmt, presumably to line up with the
    # comma-separated indices in the header.
    dot_fmt = u"{:^" + str(dim_sz + 1) + "}"
    counter_dim = self.ndimension() - 2
    # Odometer over the leading dimensions. Starting the last digit at -1
    # makes the first increment land on all zeros.
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size() - 1] = -1
    finished = False
    strt = ''
    while True:
        # Per leading dimension: did it wrap around, or jump over its middle,
        # during this advance?
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Advance the odometer, carrying from the last leading dim leftwards.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # In summarizing mode, jump from index n straight to size - n.
            if print_dots and counter[i] == n and self.size(i) > 2 * n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Wrapped: reset this digit and carry into the next one.
                # Wrapping the first digit means every slice was emitted.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            # A row with '...' under each dimension that just skipped.
            if any(nskipped):
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            # A row with a vertical ellipsis (U+22EE) under each dimension
            # that just wrapped around.
            if any(nrestarted):
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        # Separate this slice from whatever precedes it (nothing for the
        # very first slice).
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Index away the leading dims one at a time; each select(0, i) drops
        # the current first dimension, leaving the 2-D submatrix.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt
def _tensor_str(self):
    """Render a tensor as a sequence of 2-D slices, one per leading index.

    The leading ``self.ndimension() - 2`` dimensions are walked like an
    odometer, with the last leading dimension changing fastest. For each
    index combination a ``(i,j,.,.) =`` header is written, followed by the
    2-D submatrix formatted by ``_matrix_str``. When the tensor has at least
    ``PRINT_OPTS.threshold`` elements, summarizing mode is on. In that mode,
    leading dimensions longer than ``2 * PRINT_OPTS.edgeitems`` visit only
    their first and last ``edgeitems`` indices, and rows of dots are emitted
    between slices.

    NOTE(review): ``counter[counter.size() - 1]`` needs at least one leading
    dimension, so this is presumably only called for tensors with 3+ dims --
    confirm at the call site.

    Returns:
        The accumulated string.
    """
    # Number of indices kept at each end of a summarized dimension.
    n = PRINT_OPTS.edgeitems
    # Whether the last two (matrix) dimensions exceed 2 * n; only used here
    # to choose the number formatter's minimum width.
    has_hdots = self.size()[-1] > 2 * n
    has_vdots = self.size()[-2] > 2 * n
    print_full_mat = not has_hdots and not has_vdots
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    # Summarizing mode (skip middle indices, emit dot rows) for large tensors.
    print_dots = self.numel() >= PRINT_OPTS.threshold
    # Width of each index in the slice header: fits the longest dimension
    # size, and is at least 2.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    # One column wider than dim_fmt, presumably to line up with the
    # comma-separated indices in the header.
    dot_fmt = u"{:^" + str(dim_sz + 1) + "}"
    counter_dim = self.ndimension() - 2
    # Odometer over the leading dimensions. Starting the last digit at -1
    # makes the first increment land on all zeros.
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size() - 1] = -1
    finished = False
    strt = ''
    while True:
        # Per leading dimension: did it wrap around, or jump over its middle,
        # during this advance?
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Advance the odometer, carrying from the last leading dim leftwards.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # In summarizing mode, jump from index n straight to size - n.
            if print_dots and counter[i] == n and self.size(i) > 2 * n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Wrapped: reset this digit and carry into the next one.
                # Wrapping the first digit means every slice was emitted.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            # A row with '...' under each dimension that just skipped.
            if any(nskipped):
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            # A row with a vertical ellipsis (U+22EE) under each dimension
            # that just wrapped around.
            if any(nrestarted):
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        # Separate this slice from whatever precedes it (nothing for the
        # very first slice).
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Index away the leading dims one at a time; each select(0, i) drops
        # the current first dimension, leaving the 2-D submatrix.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt