def __init__(self,train_root,labels_file,type_='char',augument=True):
    '''
    Load the question arrays and the label mapping, then select the training split.

    Example:
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')

    Args:
        train_root: path handed to ``np.load``; the archive must contain
            ``title_<type_>``, ``content_<type_>`` and ``index2qid``.
        labels_file: path to a JSON file; its top-level ``'d'`` entry is
            stored as ``self.labels``.
        type_: ``'char'`` or ``'word'``; selects which title/content arrays
            are read from the archive.
        augument: stored unchanged on ``self.augument``.

    Raises:
        ValueError: if ``type_`` is neither ``'char'`` nor ``'word'``.
    '''
    import json

    # Fail fast with a clear message instead of hitting an unbound local
    # further down when an unsupported representation is requested.
    if type_ not in ('char', 'word'):
        raise ValueError("type_ must be 'char' or 'word', got %r" % (type_,))

    # The trailing `val_size` rows are held out as the validation split.
    val_size = 200000

    with open(labels_file) as f:
        labels_ = json.load(f)

    self.augument = augument
    self.type_ = type_

    # Open the archive in a `with` block so its file handle is released once
    # the arrays we need have been read into memory.
    # NOTE(review): `.item()` suggests `index2qid` may be stored as a pickled
    # object array; recent NumPy releases refuse to load those unless
    # `allow_pickle=True` is passed -- confirm against the data file before
    # enabling it, since unpickling untrusted files is unsafe.
    with np.load(train_root) as question_d:
        all_data_title = question_d['title_' + type_]
        all_data_content = question_d['content_' + type_]
        index2qid = question_d['index2qid'].item()

    self.train_data = all_data_title[:-val_size], all_data_content[:-val_size]
    self.val_data = all_data_title[-val_size:], all_data_content[-val_size:]
    self.all_num = len(all_data_content)

    # Start out serving the training split.
    self.data_title, self.data_content = self.train_data
    self.len_ = len(self.data_title)
    self.index2qid = index2qid
    self.l_end = 0
    self.labels = labels_['d']
    self.training = True
# Comment list (stray web-page text, not code)
# Table of contents (stray web-page text, not code)