dataset.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:PyTorchText 作者: chenyuntc 项目源码 文件源码
def __init__(self,train_root,labels_file,type_='char',augument=True,val_size=200000):
        '''
        Load question data and labels, and split the data into train/validation parts.

        Example::

            Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')

        Args:
            train_root: path to an ``.npz`` archive with ``title_<type_>``,
                ``content_<type_>`` and ``index2qid`` entries.
            labels_file: path to a JSON file; its ``'d'`` entry becomes ``self.labels``.
            type_: ``'char'`` or ``'word'``; selects which title/content arrays are used.
            augument: stored as ``self.augument``.
            val_size: number of trailing samples held out as validation data
                (default 200000, the previously hard-coded value).

        Raises:
            ValueError: if ``type_`` is not ``'char'``/``'word'`` or ``val_size`` is negative.
        '''
        import json

        # Fail fast: an unknown type_ used to surface later as an UnboundLocalError.
        if type_ not in ('char', 'word'):
            raise ValueError("type_ must be 'char' or 'word', got %r" % (type_,))
        if val_size < 0:
            raise ValueError('val_size must be non-negative, got %r' % (val_size,))

        with open(labels_file, encoding='utf-8') as f:
            labels_ = json.load(f)

        self.augument = augument
        self.type_ = type_

        # Use the archive as a context manager so its file handle is closed;
        # every array needed later is fetched while it is still open.
        with np.load(train_root) as question_d:
            all_data_title = question_d['title_' + type_]
            all_data_content = question_d['content_' + type_]
            # NOTE(review): if index2qid was saved as a pickled object (e.g. a dict),
            # numpy >= 1.16.3 requires np.load(..., allow_pickle=True) — confirm.
            self.index2qid = question_d['index2qid'].item()

        def _split(arr):
            # Explicit cut index: unlike arr[:-val_size], this stays correct when
            # val_size == 0, and equals the original slicing otherwise.
            cut = max(len(arr) - val_size, 0)
            return arr[:cut], arr[cut:]

        train_title, val_title = _split(all_data_title)
        train_content, val_content = _split(all_data_content)
        self.train_data = train_title, train_content
        self.val_data = val_title, val_content

        self.all_num = len(all_data_content)

        # Start out serving the training split.
        self.data_title, self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.l_end = 0
        self.labels = labels_['d']

        self.training = True
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号