def postprocess_images(self, ims):
def _postprocess_images(im):
im = tf.decode_raw(im, np.uint8)
im = tf.image.convert_image_dtype(im, dtype=tf.float32)
im = tf.reshape(im, [256, 256, 3])
im = tf.random_crop(im, [self.crop_size, self.crop_size, 3])
return im
return tf.map_fn(lambda im: _postprocess_images(im), ims, dtype=tf.float32)
评论列表
文章目录