def main(argv):
    """Convert the ImageNet training JPEGs into one labeled TFRecords file.

    Every file matched under the training directory is loaded with
    ``get_image`` and ``colorize``, rescaled to ``uint8`` and written as one
    ``tf.train.Example`` holding ``height``, ``width``, ``depth``,
    ``image_raw`` and ``label`` features.  An image's label is the position
    of its class directory name in the sorted list of class directories.

    Args:
        argv: Command-line arguments; not used by this function.

    Raises:
        AssertionError: If the number of image files or class directories
            found is not the expected one, or if a processed image does not
            have shape ``(IMSIZE, IMSIZE, 3)``.
        KeyError: If an image's parent directory is not one of the class
            directories found.
    """
    train_root = "/home/ian/imagenet/ILSVRC2012_img_train_t1_t2"

    files = glob(train_root + "/n*/*JPEG")
    # Sanity check: fewer files than this means the glob did not see the
    # complete training set (this also covers the "no files at all" case).
    assert len(files) > 1000000, len(files)

    dirs = glob(train_root + "/n*")
    assert len(dirs) == 1000, len(dirs)

    # Sort the class directory names so the name -> integer label mapping is
    # the same on every run, regardless of the order glob returns them in.
    class_names = sorted(d.split('/')[-1] for d in dirs)
    str_to_int = dict((name, idx) for idx, name in enumerate(class_names))

    outfile = ('/media/NAS_SHARED/imagenet/imagenet_train_labeled_'
               + str(IMSIZE) + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(outfile)
    try:
        for i, f in enumerate(files):
            image = get_image(f, IMSIZE, is_crop=True, resize_w=IMSIZE)
            image = colorize(image)
            assert image.shape == (IMSIZE, IMSIZE, 3)

            # Linearly map [-1, 1] onto [0, 255] before truncating to uint8.
            # NOTE(review): presumably get_image returns pixel values in
            # [-1, 1] -- confirm against its definition.
            # Done out-of-place so an integer-typed input cannot make the
            # in-place float arithmetic fail.
            image = ((image + 1.) * (255. / 2.)).astype('uint8')
            image_raw = image.tobytes()

            # The class name is the directory that directly contains the file.
            class_str = f.split('/')[-2]
            label = str_to_int[class_str]

            # One progress line per image: index and label, tab-separated.
            print('%d\t%d' % (i, label))

            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(IMSIZE),
                'width': _int64_feature(IMSIZE),
                'depth': _int64_feature(3),
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(label),
            }))
            writer.write(example.SerializeToString())
    finally:
        # Close the writer even when an image fails part-way through, so the
        # records written so far are flushed to disk.
        writer.close()
convert_imagenet_to_records.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录