utils.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:miccai17-mmwhs-hybrid 作者: xy0806 项目源码 文件源码
def get_batch_patches(img_clec, label_clec, patch_dim, batch_size, chn=1, flip_flag=True, rot_flag=True):
    """Generate a batch of paired (image, label) cubic patches for training.

    For each of the ``batch_size`` samples:
      1. an image/label pair is drawn at random (same index in both
         collections);
      2. a ``patch_dim``-sided cube is cropped at a uniformly random position,
         identical for image and label;
      3. the image patch is scaled by 1/255 and z-score normalized;
      4. if ``rot_flag`` is set, with probability 0.35 both patches are
         rotated by -25 or +25 degrees in the plane of axes (1, 0) (linear
         interpolation for the image, nearest-neighbour for the label).

    Args:
        img_clec: indexable collection of 3-D image arrays.
        label_clec: indexable collection of 3-D label arrays, aligned
            index-for-index with ``img_clec`` (assumed to have the same
            shapes as the corresponding images).
        patch_dim: edge length of the cubic patch.
        batch_size: number of patches to generate.
        chn: size of the channel axis of the returned image batch. Only
            channel ``chn - 1`` is written; any other channels stay zero.
        flip_flag: accepted for API compatibility; currently unused.
        rot_flag: enable the random rotation augmentation.

    Returns:
        (batch_img, batch_label): a float32 array of shape
        (batch_size, patch_dim, patch_dim, patch_dim, chn) and an int32
        array of shape (batch_size, patch_dim, patch_dim, patch_dim).

    Raises:
        ValueError: if a batch is requested from an empty collection, or a
            sampled volume is smaller than ``patch_dim`` along any axis.
    """
    n_vol = len(img_clec)
    if batch_size > 0 and n_vol == 0:
        raise ValueError("img_clec must contain at least one volume")

    batch_img = np.zeros([batch_size, patch_dim, patch_dim, patch_dim, chn], dtype='float32')
    batch_label = np.zeros([batch_size, patch_dim, patch_dim, patch_dim], dtype='int32')

    for k in range(batch_size):
        # randomly select an image pair (same index for image and label)
        idx = np.random.randint(n_vol)
        rand_img = img_clec[idx]
        rand_label = label_clec[idx]

        # randomly select a box anchor; valid starts along an axis of size d
        # are 0 .. d - patch_dim inclusive, hence the "+ 1"
        l, w, h = rand_img.shape
        if min(l, w, h) < patch_dim:
            raise ValueError(
                "volume of shape {} is smaller than patch_dim={}".format(rand_img.shape, patch_dim))
        pos = [np.random.randint(d - patch_dim + 1) for d in (l, w, h)]
        window = tuple(slice(p, p + patch_dim) for p in pos)

        # crop first, then cast: avoids converting the whole volume per sample
        # (astype always returns a new array, so no extra copy is needed)
        img_temp = rand_img[window].astype('float32') / 255.0
        label_temp = rand_label[window].astype('int32')

        # z-score normalization; a constant patch has zero std, so divide by 1
        # instead (giving an all-zero patch) rather than producing NaNs
        mean_temp = np.mean(img_temp)
        dev_temp = np.std(img_temp)
        img_norm = (img_temp - mean_temp) / (dev_temp if dev_temp > 0 else 1.0)

        # possible augmentation: rotation by +/-25 degrees with probability 0.35
        if rot_flag and np.random.random() > 0.65:
            angle = np.random.choice([-25, 25])
            img_norm = rotate(img_norm, angle=angle, axes=(1, 0), reshape=False, order=1)
            # order=0 (nearest) keeps label values discrete
            label_temp = rotate(label_temp, angle=angle, axes=(1, 0), reshape=False, order=0)

        batch_img[k, :, :, :, chn - 1] = img_norm
        batch_label[k] = label_temp

    return batch_img, batch_label


# calculate the cube information
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号