dataset.1.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:PyTorchText 作者: chenyuntc 项目源码 文件源码
def __init__(self,train_root,labels_file,type_='char'):
        '''
        Load the question arrays and the label mapping, then split train/val.

        Example:
            Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')

        Args:
            train_root: path handed to ``np.load``; the archive must provide
                ``title_<type_>``, ``content_<type_>`` and ``index2qid``.
            labels_file: path of a JSON file whose ``'d'`` entry is stored
                as ``self.labels``.
            type_: ``'char'`` or ``'word'``; selects which title/content
                pair is read from the archive.

        Raises:
            ValueError: if ``type_`` is neither ``'char'`` nor ``'word'``.

        Attributes set:
            train_data / val_data: ``(title, content)`` tuples; the last
                20000 entries form the validation split, everything before
                them the training split (empty when the archive holds
                20000 entries or fewer).
            data_title / data_content: the currently active split; starts
                out as the training split.
            all_num: number of entries before splitting.
            len_: number of entries in the active (training) split.
        '''
        import json

        # Reject unknown types up front. Previously an unsupported value
        # skipped both branches and crashed later with an unbound-variable
        # error that did not mention the offending argument.
        if type_ not in ('char', 'word'):
            raise ValueError(
                "type_ must be 'char' or 'word', got %r" % (type_,))

        with open(labels_file) as f:
            labels_ = json.load(f)

        # Number of trailing samples held out for validation.
        val_size = 20000

        # Pull everything needed out of the archive inside the `with`, so
        # its file handle is released as soon as the arrays are in memory.
        with np.load(train_root) as question_d:
            all_data_title = question_d['title_' + type_]
            all_data_content = question_d['content_' + type_]
            # NOTE(review): `.item()` suggests a 0-d array wrapping a single
            # Python object. If that object is pickled, recent NumPy needs
            # allow_pickle=True in np.load -- confirm against the data file
            # before enabling it (pickle must only be used on trusted data).
            index2qid = question_d['index2qid'].item()

        self.type_ = type_
        self.train_data = all_data_title[:-val_size], all_data_content[:-val_size]
        self.val_data = all_data_title[-val_size:], all_data_content[-val_size:]

        self.all_num = len(all_data_content)

        # Start on the training split.
        self.data_title, self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.index2qid = index2qid
        self.l_end = 0
        self.labels = labels_['d']

    # def augument(self,d):
    #     '''
    #     Data augmentation: random shift (the original non-ASCII comment was lost to mis-encoding)
    #     '''
    #     if self.type_=='char':
    #         _index = (-8,8)
    #     else :_index =(-5,5)
    #     r = d.new(d.size()).fill_(0)
    #     index = random.randint(-3,4)
    #     if _index >0:
    #         r[index:] = d[:-index]
    #     else:
    #         r[:-index] = d[index:]
    #     return r

    # def augument(self,d,type_=1):
    #     if type_==1:
    #         return self.shuffle(d)
    #     else :
    #         if self.type_=='char':
    #             return self.dropout(d,p=0.6)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号