cars196.py source code

python
Project: no_fuss_dml  Author: brotherofken
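import numpy as np


def iterate_minibatches(self, batchsize, shuffle=True, train=True):
    # Select samples by class: the train split holds every sample whose label
    # is in self.train_classes, the test split holds all the others, so the
    # two splits share no class.
    in_train = np.isin(self.labels, self.train_classes)
    indices = np.argwhere(in_train if train else ~in_train)  # shape (N, 1)

    if shuffle:
        np.random.shuffle(indices)

    # Step through the selected indices (not all of self.img_paths). The
    # trailing partial batch is dropped, so every yielded batch is full and
    # the generator simply ends instead of raising StopIteration.
    for start_idx in range(0, len(indices) - batchsize + 1, batchsize):
        excerpt = indices[start_idx:start_idx + batchsize]  # shape (batchsize, 1)
        # Assumes _load_preprocess_img returns an array with a leading
        # batch axis of size 1, so concatenation stacks the images.
        images = [self._load_preprocess_img(self.img_paths[int(i)])
                  for i in excerpt.ravel()]
        # For a 1-D label array this keeps the original shape (1, batchsize).
        labels = np.asarray(self.labels)[excerpt].astype(np.int32).T
        yield np.concatenate(images), labels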
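For experimenting without image files, the same split-and-batch logic can be sketched as a standalone function. This is a minimal sketch under stated assumptions, not the project's code: the preprocessed images are assumed to be stacked already in an in-memory array `features` (hypothetical; the method above loads each path with `self._load_preprocess_img`), the helper names `split_indices` and `iterate_minibatches` here are illustrative, and `np.isin`, `np.flatnonzero` and a `numpy.random.Generator` stand in for the original's `np.in1d`, `np.argwhere` and `np.random.shuffle`.

```python
import numpy as np


def split_indices(labels, train_classes, train=True):
    """Flat indices of samples whose label is in train_classes (train=True)
    or not in it (train=False), giving class-disjoint splits."""
    mask = np.isin(labels, train_classes)
    return np.flatnonzero(mask if train else ~mask)


def iterate_minibatches(features, labels, train_classes, batchsize,
                        shuffle=True, train=True, rng=None):
    """Yield (batch_features, batch_labels) from one class-based split.

    features: hypothetical in-memory array of shape (N, ...) standing in for
    preprocessed images. The trailing partial batch is dropped.
    """
    labels = np.asarray(labels)
    indices = split_indices(labels, train_classes, train)
    if shuffle:
        rng = np.random.default_rng() if rng is None else rng
        rng.shuffle(indices)  # in-place permutation of the split's indices
    # Bound the loop by the split size, not the size of the whole dataset.
    for start in range(0, len(indices) - batchsize + 1, batchsize):
        excerpt = indices[start:start + batchsize]
        yield features[excerpt], labels[excerpt].astype(np.int32)
```

For example, with `labels = np.array([0, 0, 1, 1, 2, 2, 3, 3, 3])`, `train_classes=[0, 1]` and `batchsize=2`, the train split yields two full batches and the test split yields two batches while dropping its fifth sample. Labels come back as a 1-D `int32` vector of length `batchsize`, whereas the method above returns shape `(1, batchsize)`; reshape if downstream code expects the latter. Passing a seeded `rng` makes the shuffle reproducible.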