CNN_for_tfrecords.py source code

python

Project: gong_an_pictures   Author: oukohou
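import os

import tensorflow as tf  # TF 1.x: queue-based input pipeline (not available in TF 2 eager mode)
from tensorflow.python.platform import gfile

# Assumed to be defined elsewhere in CNN_for_tfrecords.py (not shown in this listing):
#   Flags          - the tf.app.flags FLAGS object providing batch_size
#   height, width  - module-level ints giving the fixed image size stored in the records


def read_decode_tfrecords(records_path, num_epochs=1020, batch_size=Flags.batch_size, num_threads=2):
    """Read (image, label) examples from TFRecord file(s) and return shuffled batches."""
    # Accept either a single .tfrecords file or a directory of them.
    if gfile.IsDirectory(records_path):
        records_path = [os.path.join(records_path, i) for i in os.listdir(records_path)]
    else:
        records_path = [records_path]
    # Filename queue cycled num_epochs times; num_epochs creates a local variable,
    # so tf.local_variables_initializer() must be run before reading.
    records_path_queue = tf.train.string_input_producer(records_path, seed=123,
                                                        num_epochs=num_epochs,
                                                        name="string_input_producer")
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(records_path_queue, name="serialized_example")
    features = tf.parse_single_example(serialized=serialized_example,
                                       features={"img_raw": tf.FixedLenFeature([], tf.string),
                                                 "label": tf.FixedLenFeature([], tf.int64),
                                                 "height": tf.FixedLenFeature([], tf.int64),
                                                 "width": tf.FixedLenFeature([], tf.int64),
                                                 "depth": tf.FixedLenFeature([], tf.int64)},
                                       name="parse_single_example")
    # Raw bytes -> flat uint8 vector; batching needs a static shape, so it is set here
    # from the module-level height/width (3 channels assumed, i.e. RGB).
    image = tf.decode_raw(features["img_raw"], tf.uint8, name="decode_raw")
    image.set_shape([height * width * 3])
    # Scale pixel values from [0, 255] to [-0.5, 0.5].
    image = tf.cast(image, tf.float32) * (1.0 / 255) - 0.5
    label = tf.cast(features["label"], tf.int32)
    images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size, num_threads=num_threads,
                                            name="shuffle_batch", capacity=1020, min_after_dequeue=64)
    print("images' shape is :", str(images.shape))
    return images, labels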
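To make the per-record decoding and the batching step concrete outside a TensorFlow graph, here is a minimal NumPy sketch. It assumes a hypothetical 2 x 2 RGB image size (`HEIGHT`, `WIDTH`, `CHANNELS`) in place of the module-level `height` and `width`, and `decode_example` and `shuffle_batches` are illustrative helpers, not part of the project. The shuffle step is a simplified single-threaded stand-in for the random shuffle queue behind `tf.train.shuffle_batch`: it ignores `capacity` and threading, and only models `min_after_dequeue` and the dropping of a final partial batch.

```python
import random

import numpy as np

# Hypothetical image size for illustration only; the original function takes
# `height` and `width` from elsewhere in CNN_for_tfrecords.py.
HEIGHT, WIDTH, CHANNELS = 2, 2, 3


def decode_example(img_raw, label):
    """NumPy mirror of decode_raw + fixed length + scaling + label cast for one record."""
    image = np.frombuffer(img_raw, dtype=np.uint8)
    # The graph declares this fixed length with set_shape(); here it is checked explicitly.
    if image.size != HEIGHT * WIDTH * CHANNELS:
        raise ValueError("expected %d bytes, got %d" % (HEIGHT * WIDTH * CHANNELS, image.size))
    # Map pixel values from [0, 255] to [-0.5, 0.5].
    image = image.astype(np.float32) * (1.0 / 255) - 0.5
    return image, np.int32(label)


def shuffle_batches(examples, batch_size, min_after_dequeue, seed=123):
    """Simplified single-threaded shuffle buffer followed by fixed-size batching."""
    rng = random.Random(seed)
    buffer, shuffled = [], []
    for example in examples:
        buffer.append(example)
        # Only dequeue once more than min_after_dequeue elements are buffered,
        # picking a uniformly random one so nearby records get mixed.
        if len(buffer) > min_after_dequeue:
            shuffled.append(buffer.pop(rng.randrange(len(buffer))))
    # Input exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    shuffled.extend(buffer)
    # Keep only full batches (the final partial batch is dropped).
    batches = []
    for start in range(0, len(shuffled) - batch_size + 1, batch_size):
        chunk = shuffled[start:start + batch_size]
        images = np.stack([image for image, _ in chunk])
        labels = np.array([label for _, label in chunk], dtype=np.int32)
        batches.append((images, labels))
    return batches


if __name__ == "__main__":
    # Ten fake records of 12 bytes each (2 x 2 x 3), with alternating labels.
    records = [(bytes(range(i, i + 12)), i % 2) for i in range(10)]
    decoded = [decode_example(raw, label) for raw, label in records]
    batches = shuffle_batches(decoded, batch_size=4, min_after_dequeue=3)
    print(len(batches), batches[0][0].shape)  # prints: 2 (4, 12)
```

In the original TF 1.x pipeline, the returned `images` and `labels` tensors only yield data after `tf.local_variables_initializer()` has been run (required because `string_input_producer` is given `num_epochs`) and the queue runners have been started, for example with `tf.train.start_queue_runners` inside a `tf.Session`.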