utils.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:brats17 作者: xf4j 项目源码 文件源码
def read_label(path, is_training=True):
    """Load a case's segmentation and return a down-sampled 3-channel label.

    The first file matching ``*_seg.nii.gz`` inside ``path`` is loaded,
    centre-cropped to 128x128x64 and split into three binary channels
    (voxel value 1 -> channel 0, 2 -> channel 1, 4 -> channel 2).  Each
    channel is then resized to 16x16x16.  When ``is_training`` is true, one
    random transform (translation drawn uniformly from [-2, 2) per axis and a
    rotation of up to +/-5 degrees passed as the third 'sxyz' Euler angle) is
    built and applied to every channel.

    Args:
        path: Directory that contains the ``*_seg.nii.gz`` file.
        is_training: If true, apply the random augmentation described above.

    Returns:
        ``np.float32`` array of shape ``(16, 16, 16, 3)``.

    Raises:
        IndexError: If no ``*_seg.nii.gz`` file is found in ``path``.
    """
    crop_size = (128, 128, 64)
    out_size = (16, 16, 16)
    # Segmentation value stored in each output channel, in channel order.
    class_values = (1, 2, 4)

    matches = glob.glob(os.path.join(path, '*_seg.nii.gz'))
    if not matches:
        # Same exception type the bare ``[0]`` lookup raised, but with a
        # message that points at the actual problem.
        raise IndexError("no '*_seg.nii.gz' file found in {!r}".format(path))

    # ``get_data()`` is deprecated in nibabel and unavailable in recent
    # releases; ``np.asanyarray(img.dataobj)`` is its documented equivalent.
    seg = np.asanyarray(nib.load(matches[0]).dataobj).astype(np.float32)

    # Centre crop to 128*128*64.
    starts = [int((seg.shape[axis] - crop_size[axis]) / 2) for axis in range(3)]
    seg = seg[starts[0]:starts[0] + crop_size[0],
              starts[1]:starts[1] + crop_size[1],
              starts[2]:starts[2] + crop_size[2]]

    # One binary mask per class value.
    label = np.zeros(seg.shape[:3] + (len(class_values),), dtype=np.float32)
    for channel, value in enumerate(class_values):
        label[seg == value, channel] = 1

    # Down-sample each channel independently.
    final_label = np.empty(out_size + (len(class_values),), dtype=np.float32)
    for channel in range(final_label.shape[-1]):
        final_label[..., channel] = resize(label[..., channel], out_size, mode='constant')

    # Augmentation
    if is_training:
        im_size = final_label.shape[:-1]
        # Draw order (three translations, then the rotation angle) matches the
        # original code so seeded runs stay reproducible.
        translation = [np.random.uniform(-2, 2) for _ in range(3)]
        rotation = euler2mat(0, 0, np.random.uniform(-5, 5) / 180.0 * np.pi, 'sxyz')
        scale = [1, 1, 1]
        warp_mat = compose(translation, rotation, scale)
        # NOTE(review): get_tform_coords is defined outside this block; the
        # offsets below suggest it yields coordinates centred on the volume
        # origin -- confirm against its definition.
        tform_coords = get_tform_coords(im_size)
        w = np.dot(warp_mat, tform_coords)
        # Shift the transformed coordinates by half the volume size per axis.
        for axis in range(3):
            w[axis] = w[axis] + im_size[axis] / 2
        warp_coords = w[0:3].reshape(3, im_size[0], im_size[1], im_size[2])
        # The same coordinate map is used for all channels.
        for channel in range(final_label.shape[-1]):
            final_label[..., channel] = warp(final_label[..., channel], warp_coords)

    return final_label
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号