nvidia_input.py source file

python

Project: ML-Study  Author: corona10
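import tensorflow as tf  # written against the TensorFlow 1.x queue-based input API


# FLAGS and list_binary_files are defined elsewhere in nvidia_input.py.
def read_raw_images(data_set):
    dirs = './data/'+data_set+'/'
    filename = list_binary_files(dirs)
    print(filename)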
    filename_queue = tf.train.string_input_producer(filename)

    if data_set == 'train':
        # Each training record is one label byte followed by the image bytes.
        image_bytes = FLAGS.height * FLAGS.width * FLAGS.depth
        record_bytes = image_bytes + 1
        reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
        key, value = reader.read(filename_queue)
        record_bytes = tf.decode_raw(value, tf.uint8)
        label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)
        # The image is stored depth-major; reshape, then transpose to [height, width, depth].
        depth_major = tf.reshape(tf.slice(record_bytes, [1], [image_bytes]),
                                 [FLAGS.depth, FLAGS.height, FLAGS.width])
        uint8image = tf.transpose(depth_major, [1, 2, 0])
        return label, uint8image
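    elif data_set == 'test':
        image_bytes = FLAGS.height * FLAGS.width * FLAGS.depth
        # NOTE: each record is read as image_bytes + 1 bytes, but the image is sliced
        # from offset 0, so the last byte of every record is ignored.
        record_bytes = image_bytes + 1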
        reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
        key, value = reader.read(filename_queue)
        record_bytes = tf.decode_raw(value, tf.uint8)
        depth_major = tf.reshape(tf.slice(record_bytes, [0], [image_bytes]),
                                 [FLAGS.depth, FLAGS.height, FLAGS.width])
        uint8image = tf.transpose(depth_major, [1, 2, 0])
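        return uint8image
    # Fail loudly instead of silently returning None for an unknown split name.
    raise ValueError("data_set must be 'train' or 'test', got %r" % (data_set,))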
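
The record layout that the function relies on (an optional label byte followed by a depth-major image) can be checked without TensorFlow. The following is a minimal sketch in plain Python, not the author's code: `iter_records` and `decode_record` are hypothetical helper names, and the `has_label=False` case assumes a record that holds exactly the image bytes, which is an assumption rather than something the source confirms.

```python
import io


def iter_records(stream, record_bytes):
    """Yield consecutive fixed-length records from a binary stream (hypothetical helper)."""
    while True:
        record = stream.read(record_bytes)
        if not record:
            return
        if len(record) != record_bytes:
            raise ValueError("truncated record: got %d of %d bytes"
                             % (len(record), record_bytes))
        yield record


def decode_record(record, height, width, depth, has_label=True):
    """Split one record into (label, image), with image indexed as [y][x][channel]."""
    offset = 1 if has_label else 0  # assumption: unlabeled records carry no extra byte
    image_bytes = height * width * depth
    if len(record) != image_bytes + offset:
        raise ValueError("expected %d bytes, got %d"
                         % (image_bytes + offset, len(record)))
    label = record[0] if has_label else None
    pixels = record[offset:offset + image_bytes]
    plane = height * width  # size of one channel plane in the depth-major layout
    # Depth-major index is c * plane + y * width + x; reading it as [y][x][c]
    # mirrors the reshape followed by transpose([1, 2, 0]) in the function above.
    image = [[[pixels[c * plane + y * width + x] for c in range(depth)]
              for x in range(width)]
             for y in range(height)]
    return label, image


# Synthetic data: two identical records, each a label byte (7) plus a 2x2x3 image.
record = bytes([7]) + bytes(range(12))
for rec in iter_records(io.BytesIO(record * 2), record_bytes=13):
    label, image = decode_record(rec, height=2, width=2, depth=3)
    print(label, image[0][0])  # prints "7 [0, 4, 8]" for each record
```

The top-left pixel comes out as `[0, 4, 8]` because its three channel values sit one full plane (four bytes) apart in the depth-major record.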