def check_image(image):
    """Validate an RGB image tensor and pin its static channel dimension to 3.

    An assertion that the last entry of ``tf.shape(image)`` equals 3 is
    attached as a control dependency of the returned tensor, and the static
    rank is checked immediately.

    Args:
        image: tensor whose static rank must be 3 or 4 and whose last
            dimension is expected to hold the 3 color channels.

    Returns:
        An identity of ``image`` that depends on the channel-count assertion
        and whose static shape has its last dimension set to 3.

    Raises:
        ValueError: if the static rank of ``image`` is not 3 or 4.
    """
    channels_ok = tf.assert_equal(
        tf.shape(image)[-1], 3, message="image must have 3 color channels"
    )
    with tf.control_dependencies([channels_ok]):
        image = tf.identity(image)

    static_shape = image.get_shape()
    if static_shape.ndims not in (3, 4):
        raise ValueError("image must be either 3 or 4 dimensions")

    # Fix the last static dimension at 3 so that callers can unstack the
    # colors along it.
    image.set_shape(list(static_shape)[:-1] + [3])
    return image
# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
评论列表
文章目录