nn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:kaggle-review 作者: daxiongshu 项目源码 文件源码
def _batch_gen(self):
        from random import sample,randint
        self.DB.get_split()
        epochs = self.flags.epochs
        fold = self.flags.fold

        if fold>=0:
            docs_ids = list(self.DB.split[fold][0])
        else:
            docs_ids = list(range(self.DB.data['training_text'].shape[0]))

        B = min(self.flags.batch_size,len(docs_ids))
        batches_per_epoch = len(docs_ids)//B

        y = self.DB.y
        #print(batches_per_epoch)
        for epoch in range(epochs):
            for batch in range(batches_per_epoch):
                inputs = []
                labels = [] # 0 or 1
                for idx in sample(docs_ids,B):
                    inputs.append(idx+1)
                    labels.append(y[idx])
                yield inputs, labels, epoch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号