def rotate_batch(self, x_batch, axis, k):
    """Rotate ``x_batch`` by ``k`` quarter turns in the plane given by ``axis``.

    Delegates to ``rotate`` (presumably ``scipy.ndimage.rotate`` -- confirm
    against the file's imports). It passes ``reshape=False``, which in
    ``scipy.ndimage.rotate`` keeps the output shape equal to the input's.
    It also passes ``mode="nearest"`` for filling values outside the input.
    ``self`` is not used.

    Args:
        x_batch: Array to rotate.
        axis: Pair of axes defining the rotation plane, forwarded as
            ``axes=``.
        k: Number of 90-degree steps; the angle forwarded is ``k * 90``
            degrees.

    Returns:
        Whatever ``rotate`` returns for these arguments.
    """
    angle_degrees = 90 * k
    rotated = rotate(
        x_batch,
        angle_degrees,
        reshape=False,
        axes=axis,
        mode="nearest",
    )
    return rotated