def _batch_gen(self):
    """Yield random mini-batches of training document indices and their labels.

    Candidate documents come from ``self.DB`` after calling
    ``self.DB.get_split()``:

    * if ``self.flags.fold >= 0``, the candidates are the first element of
      ``self.DB.split[fold]`` (presumably the training ids of that fold --
      confirm against ``get_split``);
    * otherwise every row index ``0 .. n-1`` of
      ``self.DB.data['training_text']`` is a candidate.

    The batch size is ``B = min(self.flags.batch_size, len(candidates))`` and
    each of the ``self.flags.epochs`` epochs yields ``len(candidates) // B``
    batches.  Every batch is an independent ``random.sample`` (no repeats
    *within* a batch), so within one epoch a document may appear in several
    batches or in none.

    Yields:
        tuple: ``(inputs, labels, epoch)`` where ``inputs`` is the list of
        sampled indices each shifted by +1, ``labels`` is
        ``[self.DB.y[idx] for idx in sampled]`` in the same order, and
        ``epoch`` is the 0-based epoch counter.  Yields nothing when there
        are no candidate documents.

    Raises:
        ValueError: if ``self.flags.batch_size`` is not positive (raised on
            the first ``next()``, since this is a generator).
    """
    from random import sample

    self.DB.get_split()
    epochs = self.flags.epochs
    fold = self.flags.fold
    if fold >= 0:
        docs_ids = list(self.DB.split[fold][0])
    else:
        docs_ids = list(range(self.DB.data['training_text'].shape[0]))

    if self.flags.batch_size <= 0:
        raise ValueError(
            "flags.batch_size must be positive, got %r" % (self.flags.batch_size,)
        )
    if not docs_ids:
        # Empty split: produce zero batches instead of dividing by a zero batch size.
        return

    batch_size = min(self.flags.batch_size, len(docs_ids))
    batches_per_epoch = len(docs_ids) // batch_size
    y = self.DB.y
    for epoch in range(epochs):
        for _ in range(batches_per_epoch):
            batch_ids = sample(docs_ids, batch_size)
            # +1 shift kept from the original; presumably index 0 is reserved
            # downstream (e.g. padding) -- TODO confirm against the consumer.
            inputs = [idx + 1 for idx in batch_ids]
            # Original comment describes labels as 0 or 1 (binary) -- not checked here.
            labels = [y[idx] for idx in batch_ids]
            yield inputs, labels, epoch
评论列表
文章目录