def load_images(image_h5_file, n_images=-1, shuffle_seed=1):
    """Load images and auxiliary data from an h5 file.

    Args:
        image_h5_file: location of the h5 file; it must contain 'images'
            and 'auxvars' datasets.
        n_images: number of images to load. A negative value loads all.
            A request larger than the file is clamped to the number of
            images available and a message is printed.
        shuffle_seed: seed (or an existing np.random.RandomState) for the
            random draw used when only a subset of the images is loaded.

    Returns:
        images: array of the loaded images.
        auxvars: array of the 'auxvars' entries selected with the same
            indices as images.

    TODO: add support for multiple classes.
    """
    with h5py.File(image_h5_file, 'r') as h5file:
        n_available = len(h5file['images'])
        if n_images < 0:
            n_images = n_available
        elif n_images > n_available:
            print("Cannot load {0} images. Only {1} images in {2}".format(
                n_images, n_available, image_h5_file))
            n_images = n_available
        # Read both datasets into memory while the file is still open; the
        # returned arrays must outlive the file handle.
        images = h5file['images'][:]
        auxvars = h5file['auxvars'][:]

    if n_images < n_available:
        # Draw the subset with a seeded permutation instead of the legacy
        # cross_validation.ShuffleSplit (removed from scikit-learn in 0.20).
        # NOTE(review): taking the first n_images entries of the permutation
        # is intended to reproduce the indices ShuffleSplit yielded as its
        # test split for the same seed -- confirm if exact reproducibility
        # against old runs matters.
        if isinstance(shuffle_seed, np.random.RandomState):
            rng = shuffle_seed
        else:
            rng = np.random.RandomState(shuffle_seed)
        keep = rng.permutation(n_available)[:n_images]
        # The same indices are applied to both arrays so images and auxvars
        # stay aligned.
        images = np.take(images, keep, axis=0)
        auxvars = np.take(auxvars, keep, axis=0)
    return images, auxvars
评论列表
文章目录