dataloader.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Text-Summarization 作者: hashbangCoder 项目源码 文件源码
def getInputTextSample(self, tokenized_text):
        """Encode a single tokenized article as tensors with a batch dimension of 1.

        Args:
            tokenized_text: token sequence, passed unchanged to
                ``self.makeEncoderInput``.

        Returns:
            A 5-tuple ``(batchArticles, batchRevArticles, batchExtArticles,
            max_article_oov, article_oov)``:

            * ``batchArticles`` -- copy of ``batchExtArticles`` in which every
              id greater than ``self.vocabSize`` is replaced by
              ``self.word2id['<unk>']``; shape ``(1, len)``.
            * ``batchRevArticles`` -- the int ids from ``makeEncoderInput`` in
              reverse order; shape ``(1, len)``.
            * ``batchExtArticles`` -- the extended ids from
              ``makeEncoderInput``; shape ``(1, len)``.
            * ``max_article_oov`` -- ``len(article_oov)``.
            * ``article_oov`` -- the OOV list returned by ``makeEncoderInput``.

            All three tensors are built with ``torch.LongTensor``.
        """
        # makeEncoderInput yields (int ids, extended ids, OOV words, <unused>).
        intIds, extIds, oovWords, _ = self.makeEncoderInput(tokenized_text)

        # unsqueeze(0) adds the batch axis -- same result as stacking a
        # one-element list along dim 0.
        extBatch = torch.LongTensor(extIds).unsqueeze(0)
        revBatch = torch.LongTensor(list(reversed(intIds))).unsqueeze(0)

        # Ids above vocabSize are treated as temporary ids and become <unk>
        # in the encoder input; extBatch itself is left untouched.
        # NOTE(review): an id exactly equal to vocabSize is NOT replaced --
        # confirm makeEncoderInput never emits vocabSize as a temporary id.
        encBatch = extBatch.masked_fill(
            extBatch > self.vocabSize, self.word2id['<unk>']
        )

        # With a single article, the "max" OOV count is just its own count.
        return encBatch, revBatch, extBatch, len(oovWords), oovWords
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号