dataloader.py source code

Language: Python

Project: Text-Summarization    Author: hashbangCoder    File: dataloader.py
# Note: this is a method of the data loader class in dataloader.py; it relies
# on `import torch` at module level and on the helpers self.makeEncoderInput,
# self.word2id and self.vocabSize defined elsewhere in the class.
def evalPreproc(self, sample):
    # sample length = 1 (single example, evaluation mode)
    # limit max article size to 400 tokens
    extIntArticles, intRevArticles = [], []
    max_article_oov = 0
    article = sample['article'].split(' ')
    # int-tokenize the article; OOV words receive temporary ids past the
    # fixed vocabulary in the "extended" copy
    _intArticle, _extIntArticle, article_oov, _ = self.makeEncoderInput(article)
    if max_article_oov < len(article_oov):
        max_article_oov = len(article_oov)
    _intRevArticle = list(reversed(_intArticle))
    # _intAbstract, _extIntAbstract, abs_len = self.makeDecoderInput(abstract, article_oov)

    extIntArticles.append(_extIntArticle)
    intRevArticles.append(_intRevArticle)

    # batch of one, so no padding is actually required here
    padExtArticles = [torch.LongTensor(item) for item in extIntArticles]
    padRevArticles = [torch.LongTensor(item) for item in intRevArticles]

    batchExtArticles = torch.stack(padExtArticles, 0)
    # replace temporary OOV ids with the <unk> id for the encoder input
    batchArticles = batchExtArticles.clone().masked_fill_((batchExtArticles > self.vocabSize), self.word2id['<unk>'])
    batchRevArticles = torch.stack(padRevArticles, 0)

    return batchArticles, batchRevArticles, batchExtArticles, max_article_oov, article_oov, sample['article'], sample['abstract']
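
To make the extended-vocabulary bookkeeping concrete, below is a minimal, self-contained sketch of the id scheme this method relies on. The toy vocabulary, the make_encoder_input helper, and the exact id layout (ids 0..vocab_size-1 for in-vocab words, temporary ids from vocab_size upward for article OOVs) are illustrative assumptions, not code taken from the project:

import torch

# Hypothetical, minimal vocabulary for illustration only.
word2id = {'<unk>': 0, '<pad>': 1, 'the': 2, 'cat': 3, 'sat': 4}
vocab_size = len(word2id)

def make_encoder_input(tokens):
    """Sketch of a makeEncoderInput-style helper: in-vocab tokens keep
    their normal id in both copies; OOV tokens get <unk> in the regular
    copy and a temporary id past the fixed vocabulary in the extended copy."""
    int_article, ext_int_article, oov = [], [], []
    for tok in tokens:
        if tok in word2id:
            int_article.append(word2id[tok])
            ext_int_article.append(word2id[tok])
        else:
            if tok not in oov:
                oov.append(tok)
            int_article.append(word2id['<unk>'])
            ext_int_article.append(vocab_size + oov.index(tok))
    return int_article, ext_int_article, oov

tokens = 'the cat sat on the mat'.split(' ')
int_art, ext_art, oov = make_encoder_input(tokens)
# 'on' and 'mat' are OOV -> temporary ids 5 and 6; the words themselves
# are kept in `oov` so copied ids can be mapped back to text later.
print(ext_art)   # [2, 3, 4, 5, 2, 6]
print(oov)       # ['on', 'mat']

# The same unk-masking trick as in evalPreproc: any id outside the fixed
# vocabulary is replaced by <unk> before the encoder embedding lookup.
ext_batch = torch.LongTensor(ext_art).unsqueeze(0)
enc_input = ext_batch.clone().masked_fill_(ext_batch >= vocab_size, word2id['<unk>'])
print(enc_input)  # tensor([[2, 3, 4, 0, 2, 0]])

At decode time, the extended batch (batchExtArticles above) and the article_oov list are what allow a pointer/copy mechanism to emit ids beyond the fixed vocabulary and map them back to the original article words.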