def __init__(self, X, Y, batch_size, cropsize=0, truncate=False, sequential=False,
             random=True, val=False, class_weights=None):
    """Set up a batch generator over samples ``X`` with labels ``Y``.

    Args:
        X: Sequence of input samples; must have the same length as ``Y``.
        Y: Label array; class indices are taken with ``np.argmax(Y, 1)``
            (presumably one-hot rows -- confirm against callers).
        batch_size: Samples per batch. Coerced with ``int()``; must be >= 1
            after coercion.
        cropsize: Stored on the instance; not used by this constructor.
        truncate: If True, ``n_batches`` counts only full batches (a final
            partial batch is dropped); otherwise it is rounded up.
        sequential: Disables shuffling (and prints a notice).
        random: Shuffle once via ``self.randomize()`` at construction time.
            Ignored when ``sequential`` or ``val`` is set.
        val: Validation mode; disables shuffling.
        class_weights: Mapping of class label -> weight. Its keys (after
            ``int()``) must be exactly the set of labels present in ``Y``.
            If falsy, every label present in ``Y`` gets weight 1.0.

    Raises:
        AssertionError: If ``X`` and ``Y`` differ in length, ``batch_size``
            is below 1 after ``int()`` coercion, or the ``class_weights``
            keys do not match the labels found in ``Y``.
    """
    # NOTE: these checks use ``assert`` and are skipped under ``python -O``.
    assert len(X) == len(Y), 'X and Y must be the same length {}!={}'.format(len(X), len(Y))
    if sequential:
        print('Using sequential mode')
    print('starting normal generator')
    self.X = X
    self.Y = Y
    self.rnd_idx = np.arange(len(Y))
    self.Y_last_epoch = []
    self.val = val
    self.step = 0
    self.i = 0
    self.cropsize = cropsize
    self.truncate = truncate
    # Sequential and validation modes always keep the original sample order.
    self.random = False if sequential or val else random
    self.batch_size = int(batch_size)
    # int() truncates, so e.g. 0.5 would become 0 and no batch could be formed.
    assert self.batch_size >= 1, 'batch_size must be >= 1, got {}'.format(batch_size)
    self.sequential = sequential

    # Class index per sample, computed once and reused below.
    labels = np.argmax(Y, 1)
    self.c_weights = class_weights if class_weights else dict.fromkeys(np.unique(labels), 1.0)
    assert set(labels) == set([int(x) for x in self.c_weights.keys()]), 'not all labels in class weights'

    # Count batches from the normalized integer batch size so the count
    # matches the batches actually drawn with ``self.batch_size``.
    n_samples = len(X)
    if truncate:
        self.n_batches = n_samples // self.batch_size
    else:
        # Exact integer ceiling division (no float rounding, and unaffected
        # by Python 2's integer ``/``).
        self.n_batches = -(-n_samples // self.batch_size)

    # ``randomize`` is defined elsewhere on the class.
    if self.random:
        self.randomize()
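# NOTE(review): two lines of web-page navigation residue were captured here when
# this code was scraped: "评论列表" ("Comment list") and "文章目录" ("Table of
# contents"). They are not part of the program.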