data.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目: BinaryNet.tf 作者: itayhubara 项目源码 文件源码
def __read_imagenet(path, shuffle=True, save_file='imagenet_files.csv'):
    """Build a TF queue-based input pipeline over ImageNet JPEG files.

    On the first call (when ``save_file`` does not exist) every file matching
    ``path/*/*.JPEG`` is labelled through the module-level ``synset_map`` and
    the ``(filename, label)`` pairs are cached as rows of the CSV file
    ``save_file``. Subsequent calls read the list back from that cache.

    NOTE: the cache is keyed only by ``save_file``, not by ``path``; delete
    the CSV file (or pass a different ``save_file``) when ``path`` changes.

    Args:
        path: Root directory whose immediate sub-directories hold the JPEGs.
        shuffle: Passed through to ``tf.train.slice_input_producer``.
        save_file: Path of the CSV file used as the file-list cache.

    Returns:
        ``(images, labels)``: a decoded 3-channel float image tensor rescaled
        as ``(x - 127) / 128``, and the label converted to ``tf.int32``.

    Raises:
        ValueError: no JPEG files were found, the cache file holds no rows,
            or a file name below ``path`` contains no ``n<digits>`` synset id.
        KeyError: a synset id is missing from ``synset_map``.
    """
    import sys

    def open_csv(mode):
        # The csv module needs binary-mode files on Python 2 but text-mode
        # files opened with newline='' on Python 3 (binary mode there makes
        # csv.writer / csv.reader fail).
        if sys.version_info[0] >= 3:
            return open(save_file, mode, newline='')
        return open(save_file, mode + 'b')

    def class_index(fn):
        # Search only the part of the name below `path`, so a synset-like
        # token in the root directory itself (e.g. /data/n1/...) cannot be
        # picked up instead of the real synset directory / file name.
        match = re.search(r'(n\d+)', os.path.relpath(fn, path))
        if match is None:
            raise ValueError('no synset id (n<digits>) found in %r' % fn)
        return synset_map[match.group(1)]['index']

    if not os.path.exists(save_file):
        # glob order is arbitrary; sorting makes the cache (and therefore the
        # unshuffled iteration order) deterministic.
        file_list = sorted(glob.glob(os.path.join(path, '*', '*.JPEG')))
        if not file_list:
            raise ValueError('no *.JPEG files found under %r' % path)
        # Label every file before touching the cache so that a failure here
        # cannot leave a truncated CSV behind that later calls would trust.
        rows = [[f, class_index(f)] for f in file_list]
        with open_csv('w') as csv_file:
            # Default (minimal) quoting matches the reader below and, unlike
            # QUOTE_NONE without an escapechar, copes with commas or quotes
            # in file names. Output is identical for ordinary paths.
            csv.writer(csv_file).writerows(rows)

    with open_csv('r') as csv_file:
        # Skip blank lines so they cannot break the unzip below.
        rows = [row for row in csv.reader(csv_file) if row]
    if not rows:
        raise ValueError(
            'file list cache %r is empty; delete it and retry' % save_file)
    file_tuple, label_tuple = zip(*rows)

    filename, labels = tf.train.slice_input_producer(
        [list(file_tuple), list(label_tuple)], shuffle=shuffle)
    images = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    # Map uint8 pixels in [0, 255] to roughly [-0.99, 1.0].
    images = tf.div(tf.add(tf.to_float(images), -127), 128)
    # Labels come back from the CSV as strings.
    return images, tf.string_to_number(labels, tf.int32)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号