def transform_mnist_rts(in_data):
    """Randomly rotate, scale and translate one image sample.

    The image is rotated by an angle drawn uniformly from [-45, 45)
    degrees, rescaled by a factor drawn uniformly from [0.7, 1.2), and
    then either pasted at a random position on a zero-filled canvas of
    shape ``img_size`` or, if it does not fit strictly inside that canvas
    along either axis, resized to ``img_size``.

    Args:
        in_data: A ``(img, label)`` pair. ``img`` is indexed with ``[0]``
            to obtain a 2-D array, so it is assumed to be laid out as
            ``(1, H, W)`` -- TODO confirm against the dataset.

    Returns:
        A ``(img, label)`` pair where ``img`` is a ``float32`` array with
        a leading axis of length 1 followed by ``img_size``; ``label`` is
        passed through unchanged.

    Note:
        Relies on the module-level ``img_size`` (used here as a 2-element
        ``(height, width)`` sequence) and on ``transform`` and ``np``
        being in scope; presumably ``skimage.transform`` and ``numpy``.
    """
    img, label = in_data
    img = img[0]  # Remove channel axis for skimage manipulation

    # Rotate; resize=True grows the output so the rotated image is not
    # clipped.
    img = transform.rotate(img, angle=np.random.uniform(-45, 45),
                           resize=True, mode='constant')

    # Scale
    img = transform.rescale(img, scale=np.random.uniform(0.7, 1.2),
                            mode='constant')

    # Translate
    h, w = img.shape
    if h >= img_size[0] or w >= img_size[1]:
        # The image does not fit strictly inside the canvas, so squeeze it
        # to the target size instead of translating it.
        img = transform.resize(img, output_shape=img_size, mode='constant')
        img = img.astype(np.float32)
    else:
        img_canvas = np.zeros(img_size, dtype=np.float32)
        # np.random.randint excludes `high`, so add 1 to make the largest
        # valid offset (image flush against the bottom/right edge)
        # reachable.
        ymin = np.random.randint(0, img_size[0] - h + 1)
        xmin = np.random.randint(0, img_size[1] - w + 1)
        img_canvas[ymin:ymin + h, xmin:xmin + w] = img
        img = img_canvas

    img = img[np.newaxis, :]  # Add the channel axis back
    return img, label
评论列表
文章目录