def get_batch_patches(img_clec, label_clec, patch_dim, batch_size, chn=1, flip_flag=True, rot_flag=True):
    """Generate a batch of paired image/label patches for training.

    For each of the ``batch_size`` samples, one image/label pair is drawn
    uniformly at random (with replacement) from ``img_clec``/``label_clec``.
    A cube of edge ``patch_dim`` is then cropped from both at the same
    uniformly random position. The image patch is divided by 255 and z-score
    normalized per patch. When ``rot_flag`` is set, both patches are rotated
    by -25 or +25 degrees in the plane of axes 0 and 1 with probability 0.35.

    Args:
        img_clec: indexable sequence of 3-D image arrays.
        label_clec: indexable sequence of 3-D label arrays aligned with
            ``img_clec`` (presumably same shape as the matching image --
            TODO confirm with callers).
        patch_dim: edge length of the cubic patch.
        batch_size: number of patches to draw.
        chn: size of the channel axis of the returned image batch. The patch
            is written only into the last channel (index ``chn - 1``); any
            other channels stay 0.
        flip_flag: accepted for interface compatibility but currently unused;
            no flip augmentation is performed.
        rot_flag: enable the random rotation augmentation.

    Returns:
        ``(batch_img, batch_label)``:
            - ``batch_img``: float32 array of shape
              ``(batch_size, patch_dim, patch_dim, patch_dim, chn)``.
            - ``batch_label``: int32 array of shape
              ``(batch_size, patch_dim, patch_dim, patch_dim)``.

    Raises:
        ValueError: if a selected image is smaller than ``patch_dim`` along
            any axis.
    """
    batch_img = np.zeros([batch_size, patch_dim, patch_dim, patch_dim, chn], dtype='float32')
    batch_label = np.zeros([batch_size, patch_dim, patch_dim, patch_dim], dtype='int32')
    for k in range(batch_size):
        # Randomly select an image/label pair (one index drawn directly).
        idx = np.random.randint(len(img_clec))
        rand_img = img_clec[idx]
        rand_label = label_clec[idx]

        # Randomly select a box anchor. For an axis of size s, the valid start
        # offsets are 0..s-patch_dim inclusive, so the exclusive bound is +1.
        l, w, h = rand_img.shape
        if min(l, w, h) < patch_dim:
            raise ValueError('volume of shape %s is smaller than patch_dim=%d'
                             % (str(rand_img.shape), patch_dim))
        x0 = np.random.randint(l - patch_dim + 1)
        y0 = np.random.randint(w - patch_dim + 1)
        z0 = np.random.randint(h - patch_dim + 1)
        box = (slice(x0, x0 + patch_dim),
               slice(y0, y0 + patch_dim),
               slice(z0, z0 + patch_dim))

        # Crop first, then cast. astype() returns a fresh copy of just the
        # patch, so the source volume is neither fully converted nor aliased.
        img_temp = rand_img[box].astype('float32')
        label_temp = rand_label[box].astype('int32')

        # Normalization: divide by 255 (presumably 8-bit intensities -- TODO
        # confirm), then z-score per patch. The statistics are accumulated in
        # float64, so a constant patch gives a std of exactly 0. That case
        # uses a divisor of 1 to avoid 0/0 -> NaN in the batch.
        img_temp = img_temp / 255.0
        mean_temp = float(np.mean(img_temp, dtype=np.float64))
        dev_temp = float(np.std(img_temp, dtype=np.float64))
        if dev_temp == 0.0:
            dev_temp = 1.0
        img_norm = (img_temp - mean_temp) / dev_temp

        # Augmentation: with probability 0.35, rotate image and label by the
        # same +/-25 degree angle in the plane of axes 0 and 1.
        # - Labels use order=0 (nearest neighbour) so they stay integral.
        # - The image uses order=1 (linear interpolation).
        if rot_flag and np.random.random() > 0.65:
            angle = np.random.choice([-25, 25])
            img_norm = rotate(img_norm, angle=angle, axes=(1, 0), reshape=False, order=1)
            label_temp = rotate(label_temp, angle=angle, axes=(1, 0), reshape=False, order=0)

        batch_img[k, :, :, :, chn - 1] = img_norm
        batch_label[k, :, :, :] = label_temp
    return batch_img, batch_label
# calculate the cube information
# Comment list (web-page navigation residue)
# Table of contents (web-page navigation residue)