def next_batch(self, batch_size):
    """Load the next ``batch_size`` image/mask pairs and advance the cursor.

    File names are taken from ``self.p_imgs`` starting at
    ``self.current_idx``, wrapping around to the start of the list when the
    end is reached. Each image is read from ``self.data_path + name`` and
    each mask from ``self.mask_path + name``; both are gzip-compressed
    pickle files.

    Args:
        batch_size: Number of samples to return. ``self.current_idx`` is
            advanced by the same amount (modulo ``len(self.p_imgs)``).

    Returns:
        A tuple ``(batched_image, batched_mask)``. ``batched_image`` is the
        stack of all images, each with a new axis inserted at position 2
        (``np.expand_dims(img, axis=2)``). ``batched_mask`` is the stack of
        two-channel float64 one-hot masks: channel 0 is 1 where the mask
        equals 0, channel 1 is 1 where the mask equals 1, and any other
        mask value leaves both channels 0.

    Raises:
        AssertionError: If neither ``self.train_mode`` nor
            ``self.validation_mode`` is set.
    """
    # Explicit raise (same exception type as a plain ``assert``) so the
    # check still runs when Python is started with -O.
    if not (self.train_mode or self.validation_mode):
        raise AssertionError(
            "Please set mode, train, validation or test. e.g. DataLoad.train()"
        )

    n_files = len(self.p_imgs)
    # Use the ``batch_size`` argument (not ``self.batch_size``) so the batch
    # length matches how far the cursor is advanced below.
    names = [
        self.p_imgs[(self.current_idx + i) % n_files]
        for i in range(batch_size)
    ]

    batch_image = []
    batch_mask = []
    for name in names:
        # SECURITY NOTE(review): pickle.load can execute arbitrary code;
        # only point this loader at trusted, locally produced files.
        with gzip.open(self.data_path + name, 'rb') as fi:
            img = pickle.load(fi)
        batch_image.append(np.expand_dims(img, axis=2))

        with gzip.open(self.mask_path + name, 'rb') as fm:
            mask = pickle.load(fm)
        # One-hot encode mask values 0 and 1 into a trailing axis of size 2.
        # The mask is expected to be 2-D (H, W), giving an (H, W, 2) array.
        one_hot = np.stack([mask == 0, mask == 1], axis=-1).astype(np.float64)
        batch_mask.append(one_hot)

    self.current_idx = (self.current_idx + batch_size) % n_files
    return np.stack(batch_image), np.stack(batch_mask)
luna_preprocessed_load_data.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录