def __init__(self, train_root, labels_file, type_='char', fold=0, val_size=200000):
    '''
    Load the question arrays and the label mapping, then select the training split.

    Example:
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')

    Args:
        train_root: path of the ``.npz`` archive. It must provide the keys
            ``title_char``, ``content_char``, ``title_word``,
            ``content_word`` and ``index2qid``.
        labels_file: path of a JSON file; its top-level ``'d'`` entry is
            stored as ``self.labels``.
        type_: accepted for interface compatibility; not read by this method.
        fold: stored unchanged as ``self.fold``.
        val_size: number of trailing rows of every array that are held out
            as ``self.val_data``. If an array has fewer rows than this, all
            of its rows go to the validation split.

    Raises:
        ValueError: if ``val_size`` is negative.
    '''
    import json

    if val_size < 0:
        raise ValueError('val_size must be non-negative, got %r' % (val_size,))

    with open(labels_file) as f:
        labels_ = json.load(f)

    def _split(arr):
        # Use an explicit cut index rather than arr[:-val_size]: the
        # negative-slice form returns an empty training part for val_size == 0.
        cut = max(len(arr) - val_size, 0)
        return arr[:cut], arr[cut:]

    # The archive is closed as soon as everything needed has been read; the
    # arrays pulled out of it stay valid after the file is closed.
    # NOTE(review): allow_pickle=True is passed because 'index2qid' is read
    # with .item(), which presumably means it is a pickled object array;
    # NumPy >= 1.16.3 refuses to load those by default. Unpickling can run
    # arbitrary code, so only point train_root at a trusted file.
    with np.load(train_root, allow_pickle=True) as question_d:
        all_char_title = question_d['title_char']
        all_char_content = question_d['content_char']
        all_word_title = question_d['title_word']
        all_word_content = question_d['content_word']
        index2qid = question_d['index2qid'].item()

    char_title_train, char_title_val = _split(all_char_title)
    char_content_train, char_content_val = _split(all_char_content)
    word_title_train, word_title_val = _split(all_word_title)
    word_content_train, word_content_val = _split(all_word_content)

    self.fold = fold
    # Both splits share the layout
    # ((char_title, char_content), (word_title, word_content)).
    self.train_data = (
        (char_title_train, char_content_train),
        (word_title_train, word_content_train),
    )
    self.val_data = (
        (char_title_val, char_content_val),
        (word_title_val, word_content_val),
    )
    self.all_num = len(all_char_title)

    # NOTE(review): despite the names, data_title receives the char-level
    # (title, content) pair and data_content the word-level pair, because
    # that is how train_data is laid out. Kept as in the original.
    self.data_title, self.data_content = self.train_data
    self.len_ = len(self.data_title[0])
    self.training = True
    self.index2qid = index2qid
    self.l_end = 0
    self.labels = labels_['d']
# [web-page residue] Comment list
# [web-page residue] Table of contents