def __init__(self, train_root, labels_file, type_='char'):
    '''
    Load the question arrays and the labels, then set up the train/val split.

    Example:
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')

    Args:
        train_root: path passed to ``np.load``; the archive is read for the
            keys ``title_<type_>``, ``content_<type_>`` and ``index2qid``.
        labels_file: path to a JSON file; its top-level ``'d'`` entry is
            stored as ``self.labels``.
        type_: ``'char'`` or ``'word'``; selects which title/content arrays
            are used.

    Raises:
        ValueError: if ``type_`` is neither ``'char'`` nor ``'word'``.
    '''
    import json

    # Fail fast: an unknown type_ used to skip both branches and only blow up
    # later with an UnboundLocalError on all_data_title.
    if type_ not in ('char', 'word'):
        raise ValueError(
            "type_ must be 'char' or 'word', got {!r}".format(type_))

    with open(labels_file) as f:
        labels_ = json.load(f)

    # Pull everything we need while the archive is open so its file handle
    # is released instead of lingering for the lifetime of the process.
    # NOTE(review): the .item() call suggests index2qid may be a pickled
    # object array; if so, recent NumPy needs allow_pickle=True -- confirm.
    with np.load(train_root) as question_d:
        all_data_title = question_d['title_' + type_]
        all_data_content = question_d['content_' + type_]
        self.index2qid = question_d['index2qid'].item()

    self.type_ = type_

    # The last 20000 rows are held out as the validation set.
    val_size = 20000
    self.train_data = all_data_title[:-val_size], all_data_content[:-val_size]
    self.val_data = all_data_title[-val_size:], all_data_content[-val_size:]
    self.all_num = len(all_data_content)

    # The instance starts out serving the training split.
    self.data_title, self.data_content = self.train_data
    self.len_ = len(self.data_title)

    self.l_end = 0
    self.labels = labels_['d']
# def augument(self,d):
# '''
# Data augmentation: random shift (original note lost to mis-encoding; inferred from the code below)
# '''
# if self.type_=='char':
# _index = (-8,8)
# else :_index =(-5,5)
# r = d.new(d.size()).fill_(0)
# index = random.randint(-3,4)
# if _index >0:
# r[index:] = d[:-index]
# else:
# r[:-index] = d[index:]
# return r
# def augument(self,d,type_=1):
# if type_==1:
# return self.shuffle(d)
# else :
# if self.type_=='char':
# return self.dropout(d,p=0.6)
# [web-page residue] Comment list
# [web-page residue] Table of contents