def read_decode_tfrecords(records_path, num_epochs=1020, batch_size=Flags.batch_size, num_threads=2):
    """Build a TF1 queue-based pipeline that yields one decoded TFRecord example.

    Args:
        records_path: Path to a single TFRecord file, or to a directory whose
            (non-directory) entries are all read as TFRecord files.
        num_epochs: Currently unused -- epoch limiting is disabled (see the
            note at the ``string_input_producer`` call).
        batch_size: Currently unused -- batching is disabled (see the note
            before the return).
        num_threads: Currently unused -- batching is disabled.

    Returns:
        An ``(image, label)`` pair of unbatched tensors. ``image`` is a
        float32 vector of length ``IMAGE_PIXELS`` with the raw uint8 bytes
        rescaled from [0, 255] to [-0.5, 0.5]. ``label`` is an int32 scalar.

    Raises:
        ValueError: If ``records_path`` is a directory that contains no files.
    """
    if gfile.IsDirectory(records_path):
        # List through gfile (not os.listdir) so every filesystem accepted by
        # gfile.IsDirectory above also works here. Sub-directories are
        # skipped because TFRecordReader cannot read them. The result is
        # sorted because directory listing order is arbitrary; an unsorted
        # list would make the seeded shuffle below differ between machines.
        candidate_paths = (os.path.join(records_path, entry)
                           for entry in gfile.ListDirectory(records_path))
        record_files = sorted(path for path in candidate_paths
                              if not gfile.IsDirectory(path))
        if not record_files:
            raise ValueError("No TFRecord files found in %r" % (records_path,))
    else:
        record_files = [records_path]

    # string_input_producer shuffles the file order (seed=123). Because
    # num_epochs is not passed, it cycles over the files without limit.
    # NOTE(review): passing num_epochs=num_epochs would also require running
    # tf.local_variables_initializer() -- presumably why it was disabled.
    records_path_queue = tf.train.string_input_producer(
        record_files, seed=123, name="string_input_producer")

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(records_path_queue, name="serialized_example")

    # All five keys are required features, so records missing any of them
    # fail to parse. Only img_raw and label are used below.
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={"img_raw": tf.FixedLenFeature([], tf.string),
                  "label": tf.FixedLenFeature([], tf.int64),
                  "height": tf.FixedLenFeature([], tf.int64),
                  "width": tf.FixedLenFeature([], tf.int64),
                  "depth": tf.FixedLenFeature([], tf.int64)},
        name="parse_single_example")

    # Assumes img_raw holds exactly IMAGE_PIXELS uint8 bytes (a flattened
    # image); set_shape only asserts this length, it does not reshape.
    image = tf.decode_raw(features["img_raw"], tf.uint8, name="decode_raw")
    image.set_shape([IMAGE_PIXELS])
    # Rescale [0, 255] -> [-0.5, 0.5].
    image = tf.cast(image, tf.float32) * (1.0 / 255) - 0.5
    label = tf.cast(features["label"], tf.int32)

    # Batching is disabled, so a single unbatched example is returned. To
    # batch, wrap the result in tf.train.shuffle_batch([image, label], ...).
    # The previously used settings were batch_size=batch_size,
    # num_threads=num_threads, capacity=1020 and min_after_dequeue=64.
    return image, label
test_CNN_with_checkpoints.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录