def augment(images, labels=None, amplify=2):
    """Return the input batch followed by randomly augmented copies of it.

    The output holds ``amplify * batch_size`` images. Rows
    ``0 .. batch_size - 1`` are the unmodified inputs. For every ``j`` in
    ``1 .. amplify - 1``, row ``i + j * batch_size`` is image ``i`` passed
    through one operation picked at random (``np.random.randint``) from
    ``addBlotch``, ``shift``, ``addNoise`` and ``rotate``, which are
    callables defined outside this function.

    Args:
        images: Array of shape (batch_size, height, width, channels); the
            original note on this function states channels=3.
        labels: Optional per-image labels, documented as shape
            (batch_size, 3). Every augmented copy reuses the label of the
            image it was derived from.
        amplify: Number of output images per input image, counting the
            original itself. Must be at least 1.

    Returns:
        The augmented image array cast to uint8 when ``labels`` is None,
        otherwise a tuple ``(new_images, new_labels)``, both cast to uint8.

    Raises:
        ValueError: If ``amplify`` is smaller than 1.
    """
    if amplify < 1:
        # Without this guard a zero-sized output array is allocated and the
        # first row assignment fails with an unhelpful IndexError.
        raise ValueError("amplify must be >= 1, got {}".format(amplify))

    # Candidate augmentations; the random index below is derived from the
    # length of this table so the two can never drift apart.
    ops = (addBlotch, shift, addNoise, rotate)

    shape = images.shape
    batch_size = shape[0]
    total = amplify * batch_size

    new_images = np.zeros((total, shape[1], shape[2], shape[3]))
    # The label buffer is fixed at 3 columns, matching the documented shape.
    new_labels = np.zeros((total, 3)) if labels is not None else None

    for i in range(batch_size):
        cur_img = np.copy(images[i])
        new_images[i] = cur_img
        if new_labels is not None:
            new_labels[i] = labels[i]

        for j in range(1, amplify):
            # The j-th augmented copy of image i lives in the j-th
            # batch-sized section of the output.
            row = i + j * batch_size
            which_op = np.random.randint(low=0, high=len(ops))
            # NOTE(review): the same cur_img object is handed to every op;
            # presumably the ops return a new image of the same shape and
            # leave their argument untouched - confirm against their code.
            new_images[row] = ops[which_op](cur_img)
            if new_labels is not None:
                new_labels[row] = labels[i]

    # NOTE(review): the uint8 cast also applies to labels, so fractional or
    # out-of-range label values will not survive it - confirm this is wanted.
    if new_labels is not None:
        return new_images.astype(np.uint8), new_labels.astype(np.uint8)
    return new_images.astype(np.uint8)
# Web-page residue (not code): "comment list"
# Web-page residue (not code): "article table of contents"