def reset(self):
    """Reset the generator's state and rebuild the per-class sample index.

    Reads ``self.Y`` and ``self.batch_size``; writes ``self.step``,
    ``self.idx``, ``self.labels``, ``self.slabel``, ``self.n_per_class`` and
    ``self.n_batches``; then calls ``self.update_probabilities()``.

    ``self.Y`` is reduced with ``np.argmax(self.Y, 1)`` to one class id per
    row, so it is presumably a one-hot (or per-class score) matrix of shape
    (n_samples, n_classes) -- confirm against the caller.

    Attributes set:
        step: reset to 0.
        labels: sorted unique class ids found in ``self.Y``.
        idx: list aligned with ``labels``; ``idx[i]`` holds the row indices
            of the samples whose class is ``labels[i]``.
        slabel: position in ``labels`` (not the class id itself) of the class
            with the fewest samples; ties go to the first such class.
        n_per_class: ``batch_size // len(labels)``, as an int.
        n_batches: ``smallest // n_per_class + 1``, where ``smallest`` is the
            sample count of the smallest class.

    Raises:
        ValueError: if ``self.Y`` yields no classes, or if ``batch_size`` is
            smaller than the number of classes (``n_per_class`` would be 0).
    """
    self.step = 0
    y = np.argmax(self.Y, 1)
    labels = np.unique(y)
    if len(labels) == 0:
        raise ValueError("reset() needs at least one labelled sample in self.Y")

    idx = [np.where(y == label)[0] for label in labels]
    counts = [len(where) for where in idx]
    # The old loop seeded `smallest` with len(Y) and used a strict `>`, so
    # with a single class slabel was never assigned. argmin always picks a
    # class and, like the old loop, prefers the first one on ties.
    slabel = int(np.argmin(counts))
    smallest = counts[slabel]

    n_per_class = int(self.batch_size // len(labels))
    if n_per_class == 0:
        raise ValueError(
            f"batch_size ({self.batch_size}) must be at least the number of "
            f"classes ({len(labels)})"
        )

    self.idx = idx
    self.labels = labels
    self.slabel = slabel
    self.n_per_class = n_per_class
    # NOTE(review): this was int(np.ceil(smallest // n_per_class)) + 1; ceil of
    # an integer floor division is a no-op, so the value is unchanged here.
    # Confirm whether ceil(smallest / n_per_class) was the real intent.
    self.n_batches = smallest // n_per_class + 1
    self.update_probabilities()
评论列表
文章目录