import ast

import msgpack
import pandas as pd


def load_data(opt):
    # Load preprocessing metadata: the vocabulary-aligned pretrained embeddings.
    with open(opt['squad_dir'] + 'meta.msgpack', 'rb') as f:
        meta = msgpack.load(f, encoding='utf8')  # for msgpack>=1.0, use raw=False
    embedding = meta['embedding']
    opt['pretrained_words'] = True
    opt['vocab_size'] = len(embedding)
    opt['embedding_dim'] = len(embedding[0])

    # Load the preprocessed train/dev examples; `args` is assumed to be the
    # module-level argparse namespace.
    with open(args.data_file, 'rb') as f:
        data = msgpack.load(f, encoding='utf8')

    # If the CSV encoding is unknown, it can be detected first:
    # with open(opt['squad_dir'] + 'train.csv', 'rb') as f:
    #     charResult = chardet.detect(f.read())
    train_orig = pd.read_csv(opt['squad_dir'] + 'train.csv')  # , encoding=charResult['encoding']
    dev_orig = pd.read_csv(opt['squad_dir'] + 'dev.csv')  # , encoding=charResult['encoding']

    # One training example per tuple: context token ids, handcrafted features,
    # POS tags, entity tags, question token ids, gold answer span, raw text, spans.
    train = list(zip(
        data['trn_context_ids'], data['trn_context_features'],
        data['trn_context_tags'], data['trn_context_ents'], data['trn_question_ids'],
        train_orig['answer_start_token'].tolist(), train_orig['answer_end_token'].tolist(),
        data['trn_context_text'], data['trn_context_spans']
    ))
    dev = list(zip(
        data['dev_context_ids'], data['dev_context_features'], data['dev_context_tags'],
        data['dev_context_ents'], data['dev_question_ids'],
        data['dev_context_text'], data['dev_context_spans']
    ))

    # Gold dev answers are stored as stringified lists; parse them safely with
    # ast.literal_eval rather than eval.
    dev_y = dev_orig['answers'].tolist()[:len(dev)]
    dev_y = [ast.literal_eval(y) for y in dev_y]

    # Discover maximum sequence lengths for padding.
    opt['context_len'] = get_max_len(data['trn_context_ids'], data['dev_context_ids'])
    opt['feature_len'] = get_max_len(data['trn_context_features'][0], data['dev_context_features'][0])
    opt['question_len'] = get_max_len(data['trn_question_ids'], data['dev_question_ids'])

    print(train_orig['answer_start_token'].tolist()[:10])  # debug: spot-check answer starts
    return train, dev, dev_y, embedding, opt
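

# `get_max_len` is not defined in this snippet. A minimal sketch, assuming it
# takes one or more lists of sequences and returns the length of the longest
# sequence across all of them (consistent with the calls above):
def get_max_len(*seq_lists):
    return max(len(seq) for seqs in seq_lists for seq in seqs)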
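

# A hypothetical call site, assuming `opt` carries the SQuAD data directory
# and that `args.data_file` points at the preprocessed msgpack file:
if __name__ == '__main__':
    opt = {'squad_dir': 'SQuAD/'}
    train, dev, dev_y, embedding, opt = load_data(opt)
    print('train: %d examples, dev: %d examples' % (len(train), len(dev)))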