def read_dataset(data_dir):
    """Load the CelebA image lists, building and caching them on first use.

    Looks for a cached ``celebA.pickle`` inside *data_dir*. When it is absent,
    the image lists are built with ``create_image_lists`` from the already
    unzipped CelebA folder (named after the last path component of
    ``DATA_URL`` with its extension removed) and pickled to the cache path so
    later calls can skip the rebuild.

    Args:
        data_dir: Directory that holds the cache file and the unzipped dataset.

    Returns:
        A ``CelebA_Dataset`` constructed from the image-list mapping (which has
        at least ``'train'``, ``'test'`` and ``'validation'`` entries).

    Raises:
        ValueError: if neither the cache file nor the unzipped dataset folder
            exists in *data_dir*.
    """
    pickle_filename = "celebA.pickle"
    pickle_filepath = os.path.join(data_dir, pickle_filename)
    if not os.path.exists(pickle_filepath):
        # Automatic download (utils.maybe_download_and_extract) is disabled;
        # the archive must be downloaded and unzipped into data_dir manually.
        celebA_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]
        dir_path = os.path.join(data_dir, celebA_folder)
        if not os.path.exists(dir_path):
            print("CelebA dataset needs to be downloaded and unzipped manually")
            print("Download from: %s" % DATA_URL)
            raise ValueError("Dataset not found")
        result = create_image_lists(dir_path)
        print("Training set: %d" % len(result['train']))
        print("Test set: %d" % len(result['test']))
        print("Validation set: %d" % len(result['validation']))
        print("Pickling ...")
        # Dump to a temporary file first and rename it into place only after
        # the dump succeeds: a failed or interrupted write must never leave a
        # truncated cache that later calls would mistake for a valid one.
        tmp_filepath = pickle_filepath + ".tmp"
        try:
            with open(tmp_filepath, 'wb') as f:
                pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
            # The target was just checked to be absent, so a plain rename
            # suffices on every platform.
            os.rename(tmp_filepath, pickle_filepath)
        finally:
            # After a successful rename the temp file is gone; this only
            # cleans up after a failure.
            if os.path.exists(tmp_filepath):
                os.remove(tmp_filepath)
    else:
        print("Found pickle file!")
        # SECURITY: unpickling can execute arbitrary code. Only load caches
        # from a data_dir you trust (e.g. one this function populated).
        with open(pickle_filepath, 'rb') as f:
            result = pickle.load(f)
    return CelebA_Dataset(result)
评论列表
文章目录