mnist.py source file

python
Project: TensorFlow-ADGM  Author: dancsalo
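import numpy as np

# Method of the dataset class in mnist.py, shown without its enclosing class.
def _split_data(self):
    """Split the training set into a class-balanced labeled subset and an unlabeled remainder."""
    counts = np.zeros(self._num_classes, dtype=int)
    labeled_indices = []
    # Equal quota of labeled examples per class (rounded down).
    num_per_class = int(self._num_labels / self._num_classes)
    for i, l in enumerate(self._labels):
        # Labels are one-hot: the position of the nonzero entry is the class index.
        index = np.nonzero(l)[0][0]
        if counts[index] < num_per_class:
            counts[index] += 1
            labeled_indices.append(i)
        elif counts.sum() == num_per_class * self._num_classes:
            # Every class has reached its quota, so stop scanning. Comparing against
            # the rounded-down total also works when num_labels is not divisible
            # by num_classes.
            break
    all_indices = set(range(self._num_train_images))
    # Sort so the unlabeled images keep their original dataset order.
    unlabeled_indices = sorted(all_indices - set(labeled_indices))
    images_labeled = self._images[labeled_indices]
    images_unlabeled = self._images[unlabeled_indices]
    labels = self._labels[labeled_indices]
    return images_labeled, images_unlabeled, labels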
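
The method above depends on attributes of its class (`self._images`, `self._labels`, and so on) that are not shown here. As a rough illustration of the same class-balanced split, the sketch below uses a standalone function on toy data. The function name `split_labeled_unlabeled`, its parameters, and the toy arrays are hypothetical and not part of TensorFlow-ADGM. It assumes one-hot labels and NumPy arrays, and it picks the first few examples of each class in dataset order, as the method above does.

```python
import numpy as np


def split_labeled_unlabeled(images, one_hot_labels, num_labels):
    """Hypothetical standalone version of the class-balanced split."""
    num_classes = one_hot_labels.shape[1]
    num_per_class = num_labels // num_classes
    class_ids = one_hot_labels.argmax(axis=1)
    labeled = []
    for c in range(num_classes):
        # First num_per_class occurrences of class c, in dataset order.
        labeled.extend(np.flatnonzero(class_ids == c)[:num_per_class])
    labeled = np.sort(np.array(labeled, dtype=int))
    # Everything not selected above is treated as unlabeled.
    mask = np.ones(len(images), dtype=bool)
    mask[labeled] = False
    return images[labeled], images[mask], one_hot_labels[labeled]


# Toy data: 12 "images" of 4 pixels each, 3 classes cycling 0, 1, 2, ...
images = np.arange(12 * 4, dtype=np.float32).reshape(12, 4)
class_ids = np.arange(12) % 3
one_hot = np.eye(3, dtype=np.float32)[class_ids]

x_lab, x_unlab, y_lab = split_labeled_unlabeled(images, one_hot, num_labels=6)
print(x_lab.shape, x_unlab.shape, y_lab.sum(axis=0))
```

With 6 labels over 3 classes, each class contributes 2 labeled examples and the remaining 6 images form the unlabeled set. A boolean mask is used for the unlabeled part instead of a set difference so that the original ordering is preserved without an explicit sort.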