def next_batch(self, batch_size):
    """Sample a batch of distinct random 3-D patches and their labels.

    Patch origins ``(h, w, d)`` are drawn uniformly with ``random.randint``
    from ``[0, t_h - h] x [0, t_w - w] x [0, t_d - d]`` and de-duplicated,
    so no two patches in a batch share the same origin. The same origins
    are used to crop ``self.images`` and ``self.labels``.

    Args:
        batch_size: Number of distinct patches to return.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays. ``images`` has axes
        ``(batch, d, h, w, 1)`` and ``labels`` has axes ``(batch, d, h, w)``
        (the depth axis is moved in front of height and width, and images
        get a trailing singleton axis).

    Raises:
        ValueError: If the patch is larger than the volume along any axis,
            if ``batch_size`` exceeds the number of distinct patch origins
            (which previously made the sampling loop spin forever), or --
            raised by ``np.stack`` -- if ``batch_size`` is not positive.

    NOTE(review): assumes ``self.t_h``/``self.t_w``/``self.t_d`` are the
    extents of ``self.images`` and ``self.labels`` -- confirm against the
    constructor.
    """
    span_h = self.t_h - self.h
    span_w = self.t_w - self.w
    span_d = self.t_d - self.d
    if min(span_h, span_w, span_d) < 0:
        raise ValueError(
            "patch size ({}, {}, {}) exceeds volume size ({}, {}, {})".format(
                self.h, self.w, self.d, self.t_h, self.t_w, self.t_d))

    # The set below can never hold more than this many origins; asking for
    # more would make the rejection-sampling loop run forever.
    available = (span_h + 1) * (span_w + 1) * (span_d + 1)
    if batch_size > available:
        raise ValueError(
            "batch_size {} exceeds the {} distinct patch positions".format(
                batch_size, available))

    # Rejection sampling: keep drawing until enough unique origins exist.
    batches_ids = set()
    while len(batches_ids) < batch_size:
        h = random.randint(0, span_h)
        w = random.randint(0, span_w)
        d = random.randint(0, span_d)
        batches_ids.add((h, w, d))

    # Freeze the iteration order so images and labels stay aligned.
    origins = tuple(batches_ids)
    image_batches = [
        self.images[h:h + self.h, w:w + self.w, d:d + self.d]
        for h, w, d in origins
    ]
    label_batches = [
        self.labels[h:h + self.h, w:w + self.w, d:d + self.d]
        for h, w, d in origins
    ]

    # (batch, h, w, d) -> (batch, h, w, d, 1) -> (batch, d, h, w, 1)
    images = np.expand_dims(np.stack(image_batches, axis=0), axis=-1)
    images = np.transpose(images, (0, 3, 1, 2, 4))
    # (batch, h, w, d) -> (batch, d, h, w)
    labels = np.stack(label_batches, axis=0)
    labels = np.transpose(labels, (0, 3, 1, 2))
    return images, labels
评论列表
文章目录