def encode(self, text):
    """Encode a single string or a batch of strings into label indices.

    Each character is looked up in ``self.dict``; when ``self._ignore_case``
    is truthy the character is lower-cased before the lookup.

    Args:
        text (str or iterable of str): text(s) to convert. A batch is
            materialised once, so one-shot iterables such as generators
            are handled correctly.

    Returns:
        torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: indices
            of every character, with all texts concatenated in order.
        torch.IntTensor [n]: length of each text.

    Raises:
        TypeError: if ``text`` is neither a str nor an iterable.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so the block stays self-contained.
    from collections.abc import Iterable

    if isinstance(text, str):
        length = [len(text)]
    elif isinstance(text, Iterable):
        # Materialise first: the batch is traversed twice (lengths + join).
        text = list(text)
        length = [len(s) for s in text]
        text = ''.join(text)
    else:
        raise TypeError(
            'text must be a str or an iterable of str, got %s'
            % type(text).__name__
        )

    ignore_case = self._ignore_case
    codes = [self.dict[char.lower() if ignore_case else char] for char in text]
    return (torch.IntTensor(codes), torch.IntTensor(length))
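# (web-page residue, not part of the code) Comment list
# (web-page residue, not part of the code) Table of contents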