def _build(self, img, transform_params):
  """Resample `img` on a sampling grid produced from `transform_params`.

  Args:
    img: Image tensor to be resampled. When its static shape has exactly
      three dimensions, a trailing axis of size 1 is appended before
      sampling. NOTE(review): presumably this turns a channel-less
      [batch, height, width] image into [batch, height, width, 1];
      confirm against callers.
    transform_params: Forwarded unchanged to `self._warper`. Its output is
      used as the sampling coordinates.

  Returns:
    The tensor returned by `snt.resampler(img, grid_coords)`. The
    interpolation scheme is whatever `snt.resampler` implements.

  Raises:
    ValueError: If the static rank of `img` is unknown, because `len()` of
      an unknown-rank shape raises.
  """
  # Rank is read from the static shape, so an unknown rank fails here.
  static_rank = len(img.get_shape())
  source = tf.expand_dims(img, -1) if static_rank == 3 else img

  sample_grid = self._warper(transform_params)
  return snt.resampler(source, sample_grid)