def _create_collate_fn(self, batch_first=True):
    # Requires at module level:
    #   import torch
    #   from functools import partial
    # `self` is bound to the `this` parameter via functools.partial
    # below, rather than letting the inner function close over it.
    def collate(examples, this):
        # Unzip the batch; training examples additionally carry the
        # answer spans and answer texts needed to compute the loss.
        if this.split == "train":
            (items, question_ids, questions, questions_char, passages,
             passages_char, answers_positions, answer_texts,
             passage_tokenized) = zip(*examples)
        else:
            (items, question_ids, questions, questions_char, passages,
             passages_char, passage_tokenized) = zip(*examples)

        # Pad word-id sequences to the longest one in the batch.
        questions_tensor, question_lengths = padding(questions, this.PAD, batch_first=batch_first)
        passages_tensor, passage_lengths = padding(passages, this.PAD, batch_first=batch_first)

        # TODO: implement char-level embedding
        question_document = Documents(questions_tensor, question_lengths)
        passages_document = Documents(passages_tensor, passage_lengths)

        if this.split == "train":
            return (question_ids, question_document, passages_document,
                    torch.LongTensor(answers_positions), answer_texts)
        else:
            return question_ids, question_document, passages_document, passage_tokenized

    return partial(collate, this=self)
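
The `padding` helper and `Documents` container are defined elsewhere in the project and not shown here. Below is a minimal sketch of what they plausibly look like, inferred only from the call sites above (the exact names, fields, and return types are assumptions, not taken from the source):

from collections import namedtuple

import torch

# Hypothetical reconstruction of the helpers used above.
Documents = namedtuple("Documents", ["tensor", "lengths"])

def padding(sequences, pad_value, batch_first=True):
    """Pad variable-length id sequences into one LongTensor.

    Returns the padded tensor plus a LongTensor of original lengths.
    """
    lengths = torch.LongTensor([len(seq) for seq in sequences])
    max_len = int(lengths.max())
    padded = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = torch.LongTensor(seq)
    if not batch_first:
        padded = padded.transpose(0, 1)  # (seq_len, batch)
    return padded, lengths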
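
For context, a hedged sketch of how the returned function would be wired into a PyTorch DataLoader; `SQuADDataset` and its constructor arguments are placeholders for whatever Dataset class hosts `_create_collate_fn` in this project:

from torch.utils.data import DataLoader

# Hypothetical dataset class; the real class defining
# _create_collate_fn is not shown in this excerpt.
dataset = SQuADDataset(split="train")
loader = DataLoader(dataset,
                    batch_size=32,
                    shuffle=True,
                    collate_fn=dataset._create_collate_fn(batch_first=True))

for question_ids, questions, passages, answer_spans, answer_texts in loader:
    ...  # forward pass / loss computation goes here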