def next_normal(self):
    """Return the next batch, optionally cropped along each sample's first axis.

    Takes ``self.batch_size`` consecutive items from ``self.X`` and ``self.Y``
    starting at ``self.step * self.batch_size``, then increments ``self.step``.

    Cropping is applied only when ``self.cropsize != 0``:

    * training (``self.val`` falsy): every sample gets its own random window
      of ``self.cropsize`` rows; window starts are multiples of 5;
    * validation (``self.val`` truthy): every sample gets the centred window.

    Window positions are derived from the length of the *first* sample only,
    so all samples of a batch are assumed to share that length. Samples
    shorter than ``self.cropsize`` are returned uncropped.

    Returns:
        Validation: the float32 ``x_batch`` array alone; the batch labels are
        appended to ``self.Y_last_epoch`` instead of being returned.
        Training: the tuple ``(x_batch, y_batch)`` (float32 / int32 arrays).

    Raises:
        IndexError: if the batch slice is empty (``self.step`` points past the
            end of ``self.X``).
    """
    lo = self.step * self.batch_size
    hi = lo + self.batch_size
    x_batch = self.X[lo:hi]
    y_batch = self.Y[lo:hi]

    # Largest admissible window start. Clamped at 0 so that samples shorter
    # than the crop never yield an empty candidate set or a negative offset.
    slack = max(len(x_batch[0]) - self.cropsize, 0)

    if self.cropsize != 0:
        if self.val:
            # Deterministic centre crop for validation.
            offset = slack // 2
            x_batch = [x[offset:offset + self.cropsize] for x in x_batch]
        else:
            # Upper bound is slack + 1 (not slack + 5): a start beyond
            # ``slack`` would produce a window shorter than ``cropsize`` and
            # make the batch ragged.
            candidates = np.arange(0, slack + 1, 5)
            starts = np.random.choice(candidates, len(x_batch))
            x_batch = [x[s:s + self.cropsize] for x, s in zip(x_batch, starts)]

    x_batch = np.array(x_batch, dtype=np.float32)
    y_batch = np.array(y_batch, dtype=np.int32)
    self.step += 1

    if self.val:
        # Validation yields inputs only; keep the labels for later use.
        self.Y_last_epoch.extend(y_batch)
        return x_batch

    # NOTE: per-sample class weights are not part of the return value, so
    # ``self.c_weights`` is not consulted here.
    return (x_batch, y_batch)
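# Comment list  (blog page navigation text, not part of the code)
# Table of contents  (blog page navigation text, not part of the code)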