def prepare_serialized_examples(self, serialized_examples, width=50, height=50,
                                num_classes=10):
    """Parse a batch of serialized tf.Example protos into images and labels.

    Each example is expected to carry a PNG-encoded ``'image'`` bytes feature
    and a scalar int64 ``'label'`` feature.

    Args:
      serialized_examples: string tensor of serialized ``tf.Example`` protos
        (must be 1-D, one entry per example, as ``tf.parse_example``
        requires).
      width: target width, in pixels, of every resized image.
      height: target height, in pixels, of every resized image.
      num_classes: depth of the one-hot label encoding. Defaults to 10, the
        value this method previously hard-coded.

    Returns:
      A ``(images, labels)`` tuple:
        images: float32 tensor of shape ``[batch, height, width, 1]``
          (grayscale) scaled to the range ``[-1, 1]``.
        labels: int32 one-hot tensor of shape ``[batch, num_classes]``.
          A label outside ``[0, num_classes)`` encodes as an all-zero row.
    """
    # Mapping from tf.Example feature keys to their expected types.
    feature_map = {
        'image': tf.FixedLenFeature((), tf.string, default_value=''),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    features = tf.parse_example(serialized_examples, features=feature_map)

    def decode_and_resize(image_str_tensor):
        """Decode one PNG string, resize it and return a uint8 tensor."""
        # Output a grayscale (channels=1) image of shape [h, w, 1].
        image = tf.image.decode_png(image_str_tensor, channels=1)
        # resize_bilinear expects a leading batch dimension, but map_fn hands
        # us a single image, so add one and remove it again afterwards. The
        # resize returns float32 in the range [0, uint8_max].
        image = tf.expand_dims(image, 0)
        image = tf.image.resize_bilinear(
            image, [height, width], align_corners=False)
        image = tf.squeeze(image, axis=[0])
        # NOTE(review): tf.cast truncates the fractional part (1.8 -> 1)
        # rather than rounding. It is left unchanged so preprocessing stays
        # identical to what existing models presumably saw. Consider adding
        # tf.round before the cast if you retrain.
        image = tf.cast(image, dtype=tf.uint8)
        return image

    # decode_png works on one image at a time, hence the map_fn over the batch.
    images = tf.map_fn(
        decode_and_resize, features['image'], back_prop=False, dtype=tf.uint8)
    # uint8 -> float32 in [0, 1], then shift and scale to [-1, 1].
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)

    # One vectorized op replaces the former per-element
    # map_fn(slim.one_hot_encoding) + cast + reshape. Same int32
    # [batch, num_classes] result (on=1, off=0).
    labels = tf.cast(features['label'], tf.int32)
    labels = tf.one_hot(labels, num_classes, dtype=tf.int32)
    return images, labels
readers.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录