def _tensor_str(self):
    """Build the text form of a tensor, one 2-D slice at a time.

    The leading ``ndimension() - 2`` dimensions are walked with an
    odometer-style index. For every index combination a header of the
    form ``(i,j,.,.) = `` is written, followed by ``_matrix_str`` of the
    slice selected by that index.

    When ``self.numel() >= PRINT_OPTS.threshold``, indices between the
    first and last ``PRINT_OPTS.edgeitems`` entries of a leading dimension
    are skipped, and marker lines ('...' for a skip, U+22EE for a
    dimension that wrapped around) are emitted before the next slice.

    NOTE(review): the initial ``counter[lead_dims - 1] = -1`` presumably
    requires at least three dimensions -- confirm against the caller.

    Returns:
        The concatenated text for all printed slices.
    """
    edge = PRINT_OPTS.edgeitems
    shape = self.size()
    last_dim_truncated = shape[-1] > 2 * edge
    second_last_dim_truncated = shape[-2] > 2 * edge
    print_full_mat = not (last_dim_truncated or second_last_dim_truncated)
    formatter = _number_format(self, min_sz=0 if print_full_mat else 3)
    summarize = self.numel() >= PRINT_OPTS.threshold

    # Column width for the index header: wide enough for the largest
    # extent, and never narrower than two characters.
    width = max(2, max(len(str(extent)) for extent in shape))
    index_fmt = '{{:^{}}}'.format(width)
    marker_fmt = u'{{:^{}}}'.format(width + 1)

    lead_dims = self.ndimension() - 2
    counter = torch.LongStorage(lead_dims).fill_(0)
    # Start one step before the first index so that the first increment
    # inside the loop lands on (0, ..., 0).
    counter[lead_dims - 1] = -1

    pieces = []
    while True:
        wrapped = [False] * lead_dims
        skipped = [False] * lead_dims
        exhausted = False

        # Advance the odometer, least-significant leading dim first.
        for i in _range(lead_dims - 1, -1, -1):
            counter[i] += 1
            extent = self.size(i)
            if summarize and counter[i] == edge and extent > 2 * edge:
                # Jump over the middle of this dimension.
                counter[i] = extent - edge
                skipped[i] = True
            if counter[i] != extent:
                break
            # This digit overflowed: reset it and carry into the next one.
            if i == 0:
                exhausted = True
            counter[i] = 0
            wrapped[i] = True

        if exhausted:
            break

        if summarize:
            if any(skipped):
                pieces.extend(
                    marker_fmt.format('...' if flag else '')
                    for flag in skipped)
                pieces.append('\n')
            if any(wrapped):
                pieces.append(' ')
                pieces.extend(
                    marker_fmt.format(u'\u22EE' if flag else '')
                    for flag in wrapped)
                pieces.append('\n')

        # Separate this slice from whatever has been written so far.
        if pieces:
            pieces.append('\n')

        header = ','.join(index_fmt.format(index) for index in counter)
        pieces.append('({},.,.) = \n'.format(header))

        # Peel off one leading dimension per counter entry.
        submatrix = self
        for index in counter:
            submatrix = submatrix.select(0, index)
        pieces.append(_matrix_str(submatrix, ' ', formatter, summarize))

    return ''.join(pieces)
评论列表
文章目录