def read_label(path, is_training=True, *, crop_size=(128, 128, 64),
               out_size=(16, 16, 16)):
    """Load a segmentation volume and return a downsampled one-hot label.

    The first (sorted) file matching ``*_seg.nii.gz`` inside *path* is loaded,
    centre-cropped to *crop_size*, split into three binary channels (voxels
    equal to 1, 2 and 4 respectively), and each channel is resized to
    *out_size*.  When *is_training* is true, every channel is additionally
    warped with one random transform: a translation drawn uniformly from
    [-2, 2] per axis plus a rotation drawn uniformly from [-5, 5] degrees,
    passed as the third angle of an 'sxyz' Euler rotation.

    NOTE(review): the label values 1/2/4 look like BraTS tumour classes
    (necrotic core / edema / enhancing) -- confirm against the dataset.

    Args:
        path: Directory containing the ``*_seg.nii.gz`` file.
        is_training: Apply the random augmentation when true.
        crop_size: Size of the centred crop taken before resizing.  Axes that
            are shorter than the crop are kept whole.
        out_size: Spatial size of the returned label.

    Returns:
        float32 array of shape ``out_size + (3,)``.

    Raises:
        FileNotFoundError: If no ``*_seg.nii.gz`` file exists in *path*.
    """
    matches = sorted(glob.glob(os.path.join(path, '*_seg.nii.gz')))
    if not matches:
        raise FileNotFoundError(
            "no '*_seg.nii.gz' file found in {!r}".format(path))
    # get_data() was deprecated in nibabel 3.0 and removed in 5.0;
    # get_fdata() is its replacement and can return float32 directly.
    seg = nib.load(matches[0]).get_fdata(dtype=np.float32)

    # Centre crop.  Clamp the start at 0 so an axis shorter than the crop is
    # kept whole instead of being sliced from a negative (wrap-around) index.
    starts = [max(0, (dim - size) // 2)
              for dim, size in zip(seg.shape, crop_size)]
    seg = seg[tuple(slice(start, start + size)
                    for start, size in zip(starts, crop_size))]

    # One-hot encode: label value 1 -> channel 0, 2 -> channel 1, 4 -> channel 2.
    label = np.zeros(seg.shape[:3] + (3,), dtype=np.float32)
    for channel, value in enumerate((1, 2, 4)):
        label[seg == value, channel] = 1

    out_size = tuple(out_size)
    n_channels = label.shape[-1]
    final_label = np.empty(out_size + (n_channels,), dtype=np.float32)
    for channel in range(n_channels):
        final_label[..., channel] = resize(label[..., channel], out_size,
                                           mode='constant')

    # Augmentation.  The random draws happen in a fixed order (3 translations,
    # then 1 rotation) so seeded runs stay reproducible.
    if is_training:
        im_size = final_label.shape[:-1]
        translation = [np.random.uniform(-2, 2) for _ in range(3)]
        rotation = euler2mat(0, 0, np.random.uniform(-5, 5) / 180.0 * np.pi,
                             'sxyz')
        warp_mat = compose(translation, rotation, [1, 1, 1])
        w = np.dot(warp_mat, get_tform_coords(im_size))
        # NOTE(review): assumes get_tform_coords returns homogeneous coordinates
        # centred on the origin; adding half the size maps them back to array
        # indices -- confirm against its definition.
        for axis in range(3):
            w[axis] = w[axis] + im_size[axis] / 2
        warp_coords = w[0:3].reshape((3,) + tuple(im_size))
        for channel in range(n_channels):
            final_label[..., channel] = warp(final_label[..., channel],
                                             warp_coords)
    return final_label
# (stray web-page residue: "Comment list")
# (stray web-page residue: "Table of contents")