def getInputTextSample(self, tokenized_text):
    """Encode one tokenized article into the tensors the encoder consumes.

    Builds a batch of size 1 from a single article.

    Args:
        tokenized_text: token sequence for one article, in the form
            expected by ``self.makeEncoderInput``.

    Returns:
        Tuple of:
            batchArticles (LongTensor, shape [1, T]): in-vocab ids with
                temporary OOV ids replaced by the ``<unk>`` id.
            batchRevArticles (LongTensor, shape [1, T]): reversed in-vocab ids.
            batchExtArticles (LongTensor, shape [1, T]): extended ids that
                keep the temporary OOV ids (for pointer/copy mechanisms).
            max_article_oov (int): number of OOV words in this article.
            article_oov: the OOV word list as returned by makeEncoderInput.
    """
    # Tokenize to ids: plain ids, ids extended with temporary OOV ids,
    # and the list of OOV words. Fourth return value is unused here.
    _intArticle, _extIntArticle, article_oov, _ = self.makeEncoderInput(tokenized_text)

    # Single article, so the "max over the batch" is just this article's count.
    max_article_oov = len(article_oov)

    # Encoder reads the article right-to-left.
    _intRevArticle = list(reversed(_intArticle))

    # Batch of one: add the batch dimension directly (no padding is needed
    # for a single sequence).
    batchExtArticles = torch.LongTensor(_extIntArticle).unsqueeze(0)
    batchRevArticles = torch.LongTensor(_intRevArticle).unsqueeze(0)

    # Replace temporary OOV ids with the <unk> id for the encoder input.
    # NOTE(review): comparison is strictly `>` — this assumes temp OOV ids
    # start at vocabSize + 1; confirm against makeEncoderInput (a 0-based
    # vocab would need `>=`).
    batchArticles = batchExtArticles.clone().masked_fill_(
        (batchExtArticles > self.vocabSize), self.word2id['<unk>'])

    return batchArticles, batchRevArticles, batchExtArticles, max_article_oov, article_oov
# NOTE(review): removed stray web-page text ("评论列表" / "文章目录" — blog
# "comment list" / "article TOC" labels) left over from a copy-paste; they
# were syntax errors in this Python file.