luna_preprocessed_load_data.py source code

python

Project: lung-cancer-detector   Author: YichenGong
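import gzip
import pickle

import numpy as np


def next_batch(self, batch_size):
    """Return the next (images, one-hot masks) batch, cycling through self.p_imgs."""
    assert self.train_mode or self.validation_mode, \
        "Please set the mode to train or validation, e.g. DataLoad.train()"
    n_imgs = len(self.p_imgs)
    # Use the batch_size argument for both the indices and the cursor update,
    # wrapping past the end of the list so every batch has batch_size entries
    idx_next_batch = [(self.current_idx + i) % n_imgs for i in range(batch_size)]
    patient_img_next_batch = [self.p_imgs[idx] for idx in idx_next_batch]
    batch_image = []
    batch_mask = []
    for image in patient_img_next_batch:
        # Image and mask are gzip-compressed pickles sharing the same file name
        with gzip.open(self.data_path + image, 'rb') as fi:
            img = pickle.load(fi)
        batch_image.append(np.expand_dims(img, axis=2))  # add channel axis: (H, W) -> (H, W, 1)
        with gzip.open(self.mask_path + image, 'rb') as fm:
            mask = pickle.load(fm)
        # One-hot encode: channel 0 where mask == 0, channel 1 where mask == 1
        mask_binary_class = np.zeros([mask.shape[0], mask.shape[1], 2])
        mask_binary_class[:, :, 0][mask == 0] = 1
        mask_binary_class[:, :, 1][mask == 1] = 1
        batch_mask.append(mask_binary_class)
    self.current_idx = (self.current_idx + batch_size) % n_imgs
    batched_image = np.stack(batch_image)  # (batch_size, H, W, 1)
    batched_mask = np.stack(batch_mask)    # (batch_size, H, W, 2)
    return batched_image, batched_mask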
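`next_batch` draws files cyclically: the index list wraps around with `% len(self.p_imgs)`, so a batch that runs past the end of the file list continues from the start. A minimal stdlib sketch of that arithmetic, pulled out into a hypothetical helper `cyclic_batch_indices` (not part of the project), could look like this:

```python
def cyclic_batch_indices(start, batch_size, n):
    """Return (indices, next_start) for a batch drawn cyclically from n items.

    Indices wrap past the end of the list back to 0, so every batch has
    exactly batch_size entries; next_start is where the following batch begins.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    indices = [(start + i) % n for i in range(batch_size)]
    next_start = (start + batch_size) % n
    return indices, next_start


# Example: 5 files, batches of 3
start = 0
for _ in range(3):
    idx, start = cyclic_batch_indices(start, 3, 5)
    print(idx)
# [0, 1, 2]
# [3, 4, 0]
# [1, 2, 3]
```

Computing both the indices and the next cursor from the same `batch_size` keeps consecutive batches contiguous, with no gaps or overlaps between them.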
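The loader reads each slice from `self.data_path + image` and its mask from `self.mask_path + image` as gzip-compressed pickles. The code that writes those files is not shown here, so the following is only a sketch of a writer/reader pair consistent with that reading code; `save_gz_pickle`, `load_gz_pickle` and the file name are hypothetical. Opening the files with `with` closes them even if `pickle.load` raises.

```python
import gzip
import os
import pickle
import tempfile

import numpy as np


def save_gz_pickle(obj, path):
    """Write obj to path as a gzip-compressed pickle (hypothetical writer)."""
    with gzip.open(path, "wb") as f:
        pickle.dump(obj, f)


def load_gz_pickle(path):
    """Read back an object written by save_gz_pickle."""
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "slice_0001.pkl.gz")  # hypothetical file name
    img = np.random.rand(4, 4).astype(np.float32)
    save_gz_pickle(img, path)
    restored = load_gz_pickle(path)
    print(np.array_equal(img, restored))           # True
    print(np.expand_dims(restored, axis=2).shape)  # (4, 4, 1)
```

Because `pickle.load` can execute arbitrary code embedded in the file, this storage format is only appropriate for data you produced yourself.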
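Each mask is turned into a two-channel one-hot array: channel 0 is set where `mask == 0` and channel 1 where `mask == 1`. The same encoding can be written without per-channel index assignment by stacking the two boolean masks. The sketch below assumes 2-D integer masks; `one_hot_binary_mask` is a hypothetical name.

```python
import numpy as np


def one_hot_binary_mask(mask):
    """Convert an (H, W) mask of 0/1 labels into an (H, W, 2) one-hot array.

    Channel 0 marks pixels equal to 0, channel 1 marks pixels equal to 1.
    Any other value yields an all-zero vector, matching next_batch.
    """
    mask = np.asarray(mask)
    return np.stack([mask == 0, mask == 1], axis=-1).astype(np.float64)


mask = np.array([[0, 1],
                 [1, 0]])
encoded = one_hot_binary_mask(mask)
print(encoded.shape)   # (2, 2, 2)
print(encoded[0, 1])   # [0. 1.]
```

The result has the same shape and `float64` dtype as the `np.zeros`-based version, so either can feed `np.stack` when building the batch.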