dataset.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:R-net 作者: matthew-z 项目源码 文件源码
def _create_collate_fn(self, batch_first=True):

        def collate(examples, this):
            if this.split == "train":
                items, question_ids, questions, questions_char, passages, passages_char, answers_positions, answer_texts, passage_tokenized = zip(
                    *examples)
            else:
                items, question_ids, questions, questions_char, passages, passages_char, passage_tokenized = zip(*examples)

            questions_tensor, question_lengths = padding(questions, this.PAD, batch_first=batch_first)
            passages_tensor, passage_lengths = padding(passages, this.PAD, batch_first=batch_first)

            # TODO: implement char level embedding

            question_document = Documents(questions_tensor, question_lengths)
            passages_document = Documents(passages_tensor, passage_lengths)

            if this.split == "train":
                return question_ids, question_document, passages_document, torch.LongTensor(answers_positions), answer_texts

            else:
                return question_ids, question_document, passages_document, passage_tokenized

        return partial(collate, this=self)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号