def __read_imagenet(path, shuffle=True, save_file='imagenet_files.csv'):
    """Build a TF1 queue-based input pipeline over an ImageNet-style directory.

    On the first call, every ``*.JPEG`` file exactly one directory level below
    ``path`` is listed, its class index is looked up in the module-level
    ``synset_map`` (``synset_map[synset_id]['index']``), and
    ``(filename, index)`` rows are cached in ``save_file``.  Later calls reuse
    that CSV instead of re-scanning the directory tree.

    NOTE: the cache is keyed only on ``save_file``; if ``path`` changes,
    delete the CSV (or pass a different ``save_file``) to force a rescan.

    ``tf.train.slice_input_producer`` adds a QueueRunner, so queue runners
    must be started before the returned tensors can be evaluated.

    Args:
        path: Root directory laid out as ``path/<dir>/<image>.JPEG``, where
            the synset id (``n`` followed by digits) appears in the part of
            the path below ``path``.
        shuffle: Forwarded to ``tf.train.slice_input_producer``.
        save_file: Path of the CSV cache of ``(filename, label_index)`` rows.

    Returns:
        ``(image, label)``: one image dequeued from the input queue, decoded
        with 3 channels and scaled as ``(pixel - 127) / 128``, and its label
        as an ``int32`` tensor.

    Raises:
        ValueError: if no ``*.JPEG`` files are found, or the cache is empty.
        KeyError: if a synset id is missing from ``synset_map``.
    """
    import sys

    synset_re = re.compile(r'(n\d+)')

    def csv_open(name, mode):
        # The csv module needs binary mode on Python 2 but text mode with
        # newline='' on Python 3; plain 'wb'/'rb' breaks under Python 3.
        if sys.version_info[0] < 3:
            return open(name, mode + 'b')
        return open(name, mode, newline='', encoding='utf-8')

    def class_index(fn):
        # Search only the portion below ``path`` so that characters in the
        # root directory itself (e.g. '/home/john2/...' -> 'n2') are never
        # mistaken for the synset id.
        rel = os.path.relpath(fn, path)
        class_id = synset_re.search(rel).group(1)
        return synset_map[class_id]['index']

    if not os.path.exists(save_file):
        # Sorted so the cache, and the shuffle=False order, are reproducible.
        file_list = sorted(glob.glob(os.path.join(path, '*', '*.JPEG')))
        if not file_list:
            # Fail before creating the cache so a bad ``path`` is not
            # remembered as an empty dataset on later calls.
            raise ValueError('no *.JPEG files found under %r' % (path,))
        # Compute every row before touching the cache file so that a failed
        # lookup does not leave a truncated CSV that later calls would trust.
        rows = [(f, class_index(f)) for f in file_list]
        with csv_open(save_file, 'w') as csv_file:
            # Default (minimal) quoting keeps filenames containing commas
            # intact; unquoted rows from older caches read back the same.
            csv.writer(csv_file).writerows(rows)

    with csv_open(save_file, 'r') as f:
        file_list = list(csv.reader(f))
    if not file_list:
        raise ValueError('image list cache %r is empty' % (save_file,))

    file_tuple, label_tuple = zip(*file_list)
    filename, labels = tf.train.slice_input_producer(
        [list(file_tuple), list(label_tuple)], shuffle=shuffle)
    images = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    # Map pixel values from [0, 255] to roughly [-1, 1].
    images = tf.div(tf.add(tf.to_float(images), -127), 128)
    # Labels come back from the CSV as strings; convert to int32.
    return images, tf.string_to_number(labels, tf.int32)
评论列表
文章目录