def iterate_minibatches(self, batchsize, shuffle=True, train=True):
    """Yield ``(images, labels)`` minibatches from one class split of the data.

    Samples whose label is in ``self.train_classes`` form the training split
    (``train=True``). All remaining samples form the complementary split
    (``train=False``). Only full batches are yielded; a trailing partial
    batch is dropped.

    Args:
        batchsize: Number of samples per batch.
        shuffle: If True, visit the selected samples in random order (uses
            the global ``np.random`` state).
        train: Select the train-class split (True) or its complement (False).

    Yields:
        ``(images, labels)``. ``images`` is ``np.concatenate`` of the
        ``self._load_preprocess_img`` results for the batch's image paths.
        ``labels`` is ``self.labels`` indexed by the batch's ``(batchsize, 1)``
        index column, cast to ``int32`` and transposed.
    """
    # np.in1d is deprecated in NumPy 2.x; np.isin(...).ravel() is its
    # documented drop-in equivalent (in1d always returned a flat mask).
    in_train = np.isin(self.labels, self.train_classes).ravel()
    mask = in_train if train else ~in_train
    # argwhere on a 1-D mask yields an (n_selected, 1) column of sample indices.
    indices = np.argwhere(mask)
    if shuffle:
        np.random.shuffle(indices)
    # Bound the loop by the selected subset so every slice is a full batch.
    # Exhausting the generator is a plain fall-through: raising StopIteration
    # inside a generator is a RuntimeError since Python 3.7 (PEP 479).
    for start_idx in range(0, len(indices) - batchsize + 1, batchsize):
        excerpt = indices[start_idx:start_idx + batchsize]
        # ravel() gives scalar indices, avoiding int() on 1-element arrays.
        images = [
            self._load_preprocess_img(self.img_paths[int(i)])
            for i in excerpt.ravel()
        ]
        yield np.concatenate(images), np.array(self.labels[excerpt]).astype(np.int32).T
评论列表
文章目录