def _load_cropped_image(index):
    """Read CelebA image number ``index`` and crop it to 178x178x3.

    The file is read from ``img_align_celeba/%06i.jpg``; 20 rows are dropped
    from the top and 20 from the bottom of the decoded image.

    Raises:
        ValueError: if the cropped image is not of shape (178, 178, 3).
    """
    image = mpimg.imread('img_align_celeba/%06i.jpg' % index)[20:-20]
    if image.shape != (178, 178, 3):
        raise ValueError("Image %06i has cropped shape %s, expected (178, 178, 3)"
                         % (index, image.shape))
    return image


def _resize_image(image):
    """Resize a 178x178x3 image to (IMG_SIZE, IMG_SIZE, 3).

    INTER_AREA is used when shrinking and INTER_LANCZOS4 when enlarging; the
    image is returned unchanged when IMG_SIZE is already 178.

    Raises:
        ValueError: if the result does not have shape (IMG_SIZE, IMG_SIZE, 3).
    """
    if IMG_SIZE < 178:
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    elif IMG_SIZE > 178:
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LANCZOS4)
    if image.shape != (IMG_SIZE, IMG_SIZE, 3):
        raise ValueError("Resized image has shape %s, expected %s"
                         % (image.shape, (IMG_SIZE, IMG_SIZE, 3)))
    return image


def preprocess_images():
    """Build and cache the image tensor used for training.

    Reads images 1..N_IMAGES from ``img_align_celeba/``, crops each to
    178x178, resizes it to IMG_SIZE x IMG_SIZE, and stores everything as one
    channels-first tensor of shape (N_IMAGES, 3, IMG_SIZE, IMG_SIZE).
    Two files are written: the full tensor at IMG_PATH, and the first 20000
    images at ``images_<IMG_SIZE>_<IMG_SIZE>_20000.pth``.
    IMG_PATH, N_IMAGES and IMG_SIZE are module-level constants.

    Does nothing if IMG_PATH already exists.

    Raises:
        ValueError: if N_IMAGES is not positive, or an image has an
            unexpected shape after cropping or resizing.
    """
    if os.path.isfile(IMG_PATH):
        print("%s exists, nothing to do." % IMG_PATH)
        return
    if N_IMAGES < 1:
        raise ValueError("N_IMAGES must be positive, got %i" % N_IMAGES)

    print("Reading and resizing images from img_align_celeba/ ...")
    # Write each processed image straight into one preallocated buffer instead
    # of keeping a raw list, a resized list and a concatenated copy alive at
    # the same time (roughly 3x the final array size in peak memory).
    data = None
    for i in range(1, N_IMAGES + 1):
        if i % 10000 == 0:
            print(i)
        image = _resize_image(_load_cropped_image(i))
        if data is None:
            # Keep whatever dtype imread produced, as np.concatenate did.
            data = np.empty((N_IMAGES, 3, IMG_SIZE, IMG_SIZE), dtype=image.dtype)
        data[i - 1] = image.transpose((2, 0, 1))  # HWC -> CHW
    data = torch.from_numpy(data)

    print("Saving images to %s ..." % IMG_PATH)
    # clone() gives the slice its own compact storage; saving the bare view
    # would serialize the entire underlying storage of the full tensor.
    torch.save(data[:20000].clone(), 'images_%i_%i_20000.pth' % (IMG_SIZE, IMG_SIZE))
    torch.save(data, IMG_PATH)
评论列表
文章目录