def __iter__(self):
    """Yield one tuple of padded tensors per batch in ``self.data``.

    Each element of ``self.data`` is a batch: a sequence of per-example
    records.  A record has 9 fields when ``self.eval`` is false and 7 when
    it is true, in this order:

    * 0 - context ids (presumably token ids; padded with 0)
    * 1 - per-context-position feature vectors (all the same length)
    * 2 - per-context-position tag ids
    * 3 - per-context-position entity ids
    * 4 - question ids
    * 5, 6 - ``y_s`` / ``y_e`` targets, training only (presumably answer
      start / end positions - confirm against the caller)
    * second-to-last - ``text``; last - ``span`` (passed through as lists)

    Yields:
        In eval mode ``(context_id, context_feature, context_tag,
        context_ent, context_mask, question_id, question_mask, text, span)``;
        otherwise the same tuple with ``y_s, y_e`` inserted before ``text``.
        Id tensors are right-padded with 0, and each mask is true exactly
        where the corresponding id equals 0, so id 0 acts as padding.
        When ``self.gpu`` is set every tensor is moved to pinned memory.

    Raises:
        ValueError: if the records of a batch do not have the expected
            number of fields (7 in eval mode, 9 otherwise).
    """
    expected_fields = 7 if self.eval else 9

    def pad_ids(seqs, width):
        # Right-pad variable-length id sequences with 0 into a
        # (len(seqs), width) LongTensor.
        padded = torch.LongTensor(len(seqs), width).fill_(0)
        for row, seq in enumerate(seqs):
            if len(seq) > 0:
                padded[row, :len(seq)] = torch.LongTensor(seq)
        return padded

    for batch in self.data:
        batch_size = len(batch)
        # Transpose: a list of records becomes a list of per-field columns.
        fields = list(zip(*batch))
        # Explicit check rather than `assert`, which is removed under -O and
        # would otherwise let text/span/targets be read from the wrong fields.
        if len(fields) != expected_fields:
            raise ValueError(
                'expected {} fields per example, got {}'.format(
                    expected_fields, len(fields)))

        context_len = max(len(doc) for doc in fields[0])
        context_id = pad_ids(fields[0], context_len)

        # Take the feature width from the first non-empty context so that an
        # empty leading document does not raise IndexError.
        feature_len = next(
            (len(doc[0]) for doc in fields[1] if len(doc) > 0), 0)
        context_feature = torch.Tensor(
            batch_size, context_len, feature_len).fill_(0)
        for row, doc in enumerate(fields[1]):
            if len(doc) > 0:
                # One tensor per document instead of one per position.
                context_feature[row, :len(doc), :] = torch.Tensor(doc)

        # Tags and entities are aligned with the context, so they share its width.
        context_tag = pad_ids(fields[2], context_len)
        context_ent = pad_ids(fields[3], context_len)

        question_len = max(len(q) for q in fields[4])
        question_id = pad_ids(fields[4], question_len)

        context_mask = torch.eq(context_id, 0)
        question_mask = torch.eq(question_id, 0)

        text = list(fields[-2])
        span = list(fields[-1])

        tensors = [context_id, context_feature, context_tag, context_ent,
                   context_mask, question_id, question_mask]
        if not self.eval:
            tensors += [torch.LongTensor(fields[5]),
                        torch.LongTensor(fields[6])]
        if self.gpu:
            # Pin every tensor, including the training targets, for faster
            # host-to-device copies.
            tensors = [t.pin_memory() for t in tensors]
        yield tuple(tensors) + (text, span)
# Residue scraped from the source web page (not code):
# "评论列表" (comment list) / "文章目录" (table of contents)