def decode(self, t, length, raw=False):
    """Decode CTC label indices into text.

    Args:
        t: Flat sequence of label indices, e.g. a 1-D integer tensor
            (assumed 1-D -- confirm against callers). When ``length``
            holds several entries, ``t`` is the concatenation of all
            sequences in batch order.
        length: Tensor of sequence lengths. A single-element tensor
            (shape ``(1,)`` or 0-d) selects single-sequence mode; any
            other size selects batch mode.
        raw: If True, map every index straight to a character without
            removing blanks or collapsing repeated labels.

    Returns:
        A ``str`` in single-sequence mode, otherwise a ``list`` of ``str``
        with one entry per element of ``length``.

    Index ``k > 0`` maps to ``self.alphabet[k - 1]``. Index 0 is treated as
    the blank: it is dropped in non-raw mode, while in raw mode it maps to
    ``self.alphabet[-1]`` (presumably a blank placeholder appended to the
    alphabet -- confirm against the constructor).
    """
    alphabet = self.alphabet
    # Convert once up front: indexing a tensor element by element creates
    # a new 0-d tensor on every access, which is slow.
    labels = t.tolist() if hasattr(t, "tolist") else list(t)

    def _decode_one(seq):
        # Decode one label sequence given as a list of Python ints.
        if raw:
            return ''.join(alphabet[k - 1] for k in seq)
        chars = []
        prev = None
        for k in seq:
            # Greedy CTC collapse: drop blanks and repeats of the previous
            # label (a blank between two equal labels keeps both).
            if k != 0 and k != prev:
                chars.append(alphabet[k - 1])
            prev = k
        return ''.join(chars)

    if length.numel() == 1:
        # .item() works for both shape (1,) and 0-d tensors, unlike length[0].
        n = int(length.item())
        # Iterate the sliced sequence itself rather than range(n). A `t`
        # shorter than `length` then decodes the available labels instead
        # of raising IndexError.
        return _decode_one(labels[:n])

    texts = []
    offset = 0
    for n in length.flatten().tolist():
        n = int(n)
        texts.append(_decode_one(labels[offset:offset + n]))
        offset += n
    return texts
# (scraped web-page residue: "Comment list", "Table of contents")