def _split_data(self):
    """Split the training set into a class-balanced labeled part and the rest.

    Walks ``self._labels`` in order and keeps the first
    ``self._num_labels // self._num_classes`` samples seen for each class.
    Every other index in ``range(self._num_train_images)`` is treated as
    unlabeled.

    The class of a sample is the position of the first non-zero entry of
    its label row (presumably the labels are one-hot encoded -- confirm
    against the code that populates ``self._labels``). A row with no
    non-zero entry raises ``IndexError`` if it is reached before the quota
    is filled.

    Returns:
        A tuple ``(images_labeled, images_unlabeled, labels)`` where
        ``images_labeled`` and ``labels`` are ``self._images`` /
        ``self._labels`` indexed by the selected positions (in scan order)
        and ``images_unlabeled`` is ``self._images`` indexed by the
        remaining positions in ascending order.
    """
    num_per_class = int(self._num_labels // self._num_classes)
    # The quota actually reachable; differs from self._num_labels when that
    # is not a multiple of the class count, so compare against this value
    # (comparing against self._num_labels would never terminate early).
    target_total = num_per_class * self._num_classes

    counts = np.zeros(self._num_classes, dtype=np.int64)
    labeled_indices = []
    if target_total > 0:
        for i, label in enumerate(self._labels):
            class_index = np.nonzero(label)[0][0]
            if counts[class_index] >= num_per_class:
                # This class already has its share; keep scanning.
                continue
            counts[class_index] += 1
            labeled_indices.append(i)
            if len(labeled_indices) == target_total:
                # Every class is full: no need to look at further rows.
                break

    # Sort so the unlabeled order does not depend on set iteration order.
    unlabeled_indices = sorted(
        set(range(self._num_train_images)) - set(labeled_indices)
    )

    images_labeled = self._images[labeled_indices]
    images_unlabeled = self._images[unlabeled_indices]
    labels = self._labels[labeled_indices]
    return images_labeled, images_unlabeled, labels
评论列表
文章目录