train.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DrQA-TF 作者: zhaolewen 项目源码 文件源码
def load_data(opt):
    """Load the preprocessed SQuAD data and assemble train/dev example lists.

    Reads ``meta.msgpack``, ``train.csv`` and ``dev.csv`` from the directory
    ``opt["squad_dir"]``, and the msgpack data file named by the module-level
    ``args.data_file``.

    Args:
        opt: Mutable options mapping; must contain ``"squad_dir"``. It is
            updated in place with ``pretrained_words``, ``vocab_size``,
            ``embedding_dim``, ``context_len``, ``feature_len`` and
            ``question_len``.

    Returns:
        A tuple ``(train, dev, dev_y, embedding, opt)`` where

        * ``train`` is a list of 9-tuples: context ids, context features,
          context tags, context ents, question ids, answer start token,
          answer end token, context text, context spans;
        * ``dev`` is a list of 7-tuples: context ids, context features,
          context tags, context ents, question ids, context text,
          context spans;
        * ``dev_y`` holds the evaluated ``answers`` cells of ``dev.csv``,
          one per dev example;
        * ``embedding`` is ``meta['embedding']`` as loaded;
        * ``opt`` is the same mapping that was passed in.
    """
    # Local import keeps this block self-contained.
    import os

    squad_dir = opt["squad_dir"]

    # os.path.join yields the same path as plain concatenation when squad_dir
    # already ends with a separator, and a correct one when it does not.
    # NOTE(review): the ``encoding`` keyword was removed in msgpack 1.0
    # (``raw=False`` is the replacement) -- confirm the pinned msgpack version.
    with open(os.path.join(squad_dir, 'meta.msgpack'), 'rb') as f:
        meta = msgpack.load(f, encoding='utf8')
    embedding = meta['embedding']

    opt['pretrained_words'] = True
    opt['vocab_size'] = len(embedding)
    opt['embedding_dim'] = len(embedding[0])

    # ``args`` is not a parameter; it is expected to exist at module level.
    with open(args.data_file, 'rb') as f:
        data = msgpack.load(f, encoding='utf8')

    train_orig = pd.read_csv(os.path.join(squad_dir, 'train.csv'))
    dev_orig = pd.read_csv(os.path.join(squad_dir, 'dev.csv'))

    answer_starts = train_orig['answer_start_token'].tolist()
    answer_ends = train_orig['answer_end_token'].tolist()

    # zip() stops at the shortest input, so a length mismatch between the
    # msgpack columns and the CSV columns silently truncates the example list.
    train = list(zip(
        data['trn_context_ids'], data['trn_context_features'],
        data['trn_context_tags'], data['trn_context_ents'],
        data['trn_question_ids'],
        answer_starts, answer_ends,
        data['trn_context_text'], data['trn_context_spans'],
    ))
    dev = list(zip(
        data['dev_context_ids'], data['dev_context_features'],
        data['dev_context_tags'], data['dev_context_ents'],
        data['dev_question_ids'],
        data['dev_context_text'], data['dev_context_spans'],
    ))

    # Keep only as many reference answers as there are dev examples.
    dev_y = dev_orig['answers'].tolist()[:len(dev)]
    # SECURITY NOTE(review): eval() executes arbitrary expressions found in
    # dev.csv. If the cells are plain Python literals, ast.literal_eval would
    # be a safe replacement -- confirm the CSV format before switching.
    dev_y = [eval(y) for y in dev_y]

    # discover lengths
    opt['context_len'] = get_max_len(data['trn_context_ids'], data['dev_context_ids'])
    # Only the first example's feature rows are passed here.
    opt['feature_len'] = get_max_len(data['trn_context_features'][0], data['dev_context_features'][0])
    opt['question_len'] = get_max_len(data['trn_question_ids'], data['dev_question_ids'])

    # Debug output: the first ten answer start tokens.
    print(answer_starts[:10])

    return train, dev, dev_y, embedding, opt
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号