def image_loader(self):
    """Return the next batch of resized images, reshuffling once per epoch.

    Reads ``self.images_`` (each entry is passed to ``io.imread``, presumably
    a file path), ``self.batch_size``, ``self.img_h`` and ``self.img_w``, and
    advances the cursor ``self.cur_`` through the shuffled index order
    ``self.perm_``.

    When fewer than ``batch_size`` images remain in the current permutation,
    the cursor is reset to 0 and a fresh permutation is drawn. Leftover images
    at the tail of an epoch are therefore skipped for that epoch rather than
    wrapped around into the next batch.

    Returns:
        ``np.ndarray`` of shape ``(batch_size, img_h, img_w, 3)`` with NumPy's
        default float dtype. If there are fewer than ``batch_size`` images in
        total, the unfilled trailing entries stay zero.
    """
    n_images = len(self.images_)

    # Reshuffle when the next full batch would run past the end of the
    # permutation. Use ``>``, not ``>=``: a batch ending exactly on the last
    # index is still a full batch and should be served. Also reshuffle if no
    # permutation exists yet, so the first call works however the object
    # was initialised.
    if (self.cur_ + self.batch_size > n_images
            or getattr(self, "perm_", None) is None):
        self.cur_ = 0
        # ``permutation(n)`` returns a shuffled ``arange(n)``. The original
        # ``np.range`` does not exist in NumPy.
        self.perm_ = np.random.permutation(n_images)

    img_batch = np.zeros([self.batch_size, self.img_h, self.img_w, 3])
    batch_indices = self.perm_[self.cur_:self.cur_ + self.batch_size]
    for i, idx in enumerate(batch_indices):
        img = transform.resize(io.imread(self.images_[idx]),
                               (self.img_h, self.img_w))
        if img.ndim == 2:
            # Grayscale: add a channel axis so it broadcasts across RGB.
            img = img[..., np.newaxis]
        elif img.shape[-1] == 4:
            # RGBA: drop the alpha channel to fit the 3-channel slot.
            img = img[..., :3]
        img_batch[i, ...] = img

    self.cur_ += self.batch_size
    return img_batch
评论列表
文章目录