def load_data(opt):
    """Load the preprocessed SQuAD data, the word embedding and the dev answers.

    Reads ``SQuAD/meta.msgpack`` for the embedding matrix and the msgpack
    data file named by ``opt['data_file']``. When ``opt`` has no
    ``'data_file'`` key, it falls back to the module-level ``args.data_file``.
    It also reads the original ``SQuAD/train.csv`` and ``SQuAD/dev.csv``
    tables.

    Args:
        opt: Mutable mapping of options. Reads ``'fix_embeddings'`` and,
            if present, ``'data_file'``. Sets ``'pretrained_words'``,
            ``'vocab_size'`` and ``'embedding_dim'`` in place.

    Returns:
        A tuple ``(train, dev, dev_y, embedding, opt)`` where:

        - ``train`` is a list of tuples: context ids, context features,
          context tags, context entities, question ids, answer start token,
          answer end token, context text and context spans.
        - ``dev`` is a list of tuples with the same fields minus the two
          answer positions.
        - ``dev_y`` holds the parsed ``answers`` column of ``dev.csv``,
          truncated to ``len(dev)``.
        - ``embedding`` is the ``torch.Tensor`` built from
          ``meta['embedding']``.
        - ``opt`` is the same (updated) options object.
    """
    import ast  # local import: only needed to parse the dev answers column

    # `raw=False` decodes msgpack strings as UTF-8 text. It replaces the
    # `encoding='utf8'` argument, which msgpack deprecated in 0.5.2 and
    # removed in 1.0 (passing it now raises TypeError).
    with open('SQuAD/meta.msgpack', 'rb') as f:
        meta = msgpack.load(f, raw=False)
    embedding = torch.Tensor(meta['embedding'])
    opt['pretrained_words'] = True
    opt['vocab_size'] = embedding.size(0)
    opt['embedding_dim'] = embedding.size(1)
    if not opt['fix_embeddings']:
        # Re-initialise row 1 with N(0, 1) noise. Row 1 is presumably the
        # <UNK> token; TODO confirm against the preprocessing vocab.
        # The mean is passed positionally because the old `means=` keyword
        # is rejected by current PyTorch (which names it `mean`).
        embedding[1] = torch.normal(torch.zeros(opt['embedding_dim']), 1.)

    # Prefer the path carried in `opt`. Fall back to the module-level `args`
    # that this function used to read, for callers whose opt lacks the key.
    if 'data_file' in opt:
        data_file = opt['data_file']
    else:
        data_file = args.data_file
    with open(data_file, 'rb') as f:
        data = msgpack.load(f, raw=False)

    train_orig = pd.read_csv('SQuAD/train.csv')
    dev_orig = pd.read_csv('SQuAD/dev.csv')
    # NOTE(review): zip() silently truncates to the shortest input, so a
    # length mismatch between the msgpack data and the CSVs is not detected.
    train = list(zip(
        data['trn_context_ids'],
        data['trn_context_features'],
        data['trn_context_tags'],
        data['trn_context_ents'],
        data['trn_question_ids'],
        train_orig['answer_start_token'].tolist(),
        train_orig['answer_end_token'].tolist(),
        data['trn_context_text'],
        data['trn_context_spans'],
    ))
    dev = list(zip(
        data['dev_context_ids'],
        data['dev_context_features'],
        data['dev_context_tags'],
        data['dev_context_ents'],
        data['dev_question_ids'],
        data['dev_context_text'],
        data['dev_context_spans'],
    ))
    dev_y = dev_orig['answers'].tolist()[:len(dev)]
    # SECURITY: the answers column presumably holds the string form of a
    # Python list of answer strings (verify against the preprocessing
    # script). ast.literal_eval parses such literals without executing
    # arbitrary code from the CSV, which the previous eval() would do.
    dev_y = [ast.literal_eval(y) for y in dev_y]
    return train, dev, dev_y, embedding, opt
评论列表
文章目录