def _imagenet_load_file(path, epochs=None, shuffle=True, seed=0, subset='train', prepare_path=True):
    """Build TF1 queue-based ops that read one (image, label) record from list files.

    Each line of a list file is parsed as two space-separated string
    fields: an image path and a label.

    Args:
        path: A list-file name, or a list/tuple of list-file names.
        epochs: Forwarded as ``num_epochs`` to
            ``tf.train.string_input_producer``.
        shuffle: Forwarded to ``tf.train.string_input_producer``.
        seed: Forwarded to ``tf.train.string_input_producer``.
        subset: Name appended to ``<IMAGENET_DIR>/images/`` when
            ``prepare_path`` is true.
        prepare_path: If true, the image path read from the list file is
            prefixed with ``$IMAGENET_DIR + '/images/' + subset``;
            otherwise it is used as-is.

    Returns:
        A tuple ``(image, label, imgshape, image_path)`` of tensors:
        the decoded 3-channel image, the label converted to int32, the
        first two entries of the image's dynamic shape, and the image
        path exactly as it appeared in the list file (without prefix).
    """
    # Falls back to '' when IMAGENET_DIR is unset, so prefixed paths then
    # start at '/images/...'.
    IMAGENET_ROOT = os.environ.get('IMAGENET_DIR', '')

    # The producer needs a flat sequence of file names. A tuple is already
    # one; wrapping it in a list would produce a nested (2-D) input.
    if not isinstance(path, (list, tuple)):
        path = [path]

    filename_queue = tf.train.string_input_producer(
        path, num_epochs=epochs, shuffle=shuffle, seed=seed)

    reader = tf.TextLineReader()
    # The record key returned by the reader is not needed here.
    _, value = reader.read(filename_queue)
    image_path, label_str = tf.decode_csv(
        value, record_defaults=[[''], ['']], field_delim=' ')

    if prepare_path:
        # NOTE(review): no separator is inserted between `subset` and
        # `image_path`; presumably list-file paths start with '/' -- confirm.
        image_abspath = IMAGENET_ROOT + '/images/' + subset + image_path
    else:
        image_abspath = image_path

    image_content = tf.read_file(image_abspath)
    image = decode_image(image_content, channels=3)
    # Fix the static rank/channel count; the two leading dims stay unknown.
    image.set_shape([None, None, 3])

    imgshape = tf.shape(image)[:2]
    label = tf.string_to_number(label_str, out_type=tf.int32)

    return image, label, imgshape, image_path
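# Comment list  (leftover heading from the source web page)
# Table of contents  (leftover heading from the source web page)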