def generate(batch_size, pretty=False, noise_dim=100, pretty_factor=20,
             output_path="images/generated_image.png"):
    """Generate a grid of images with the trained generator and save it as PNG.

    Builds the generator via ``generator_model()``, loads weights from
    ``'generator_weights'`` and feeds it latent vectors drawn uniformly from
    [-1, 1]. When ``pretty`` is true, ``batch_size * pretty_factor`` images are
    generated, scored with the discriminator (weights from
    ``'discriminator_weights'``), and only the ``batch_size`` highest-scored
    images are kept. The result is tiled with ``combine_images``, rescaled
    from [-1, 1] to [0, 255] and written to ``output_path``.

    Args:
        batch_size: Number of images in the output grid (int).
        pretty: If true, keep only the images the discriminator scores highest.
        noise_dim: Length of each latent vector; must match the generator's
            input size. Defaults to 100, the previously hard-coded value.
        pretty_factor: Oversampling multiplier used when ``pretty`` is true.
            Defaults to 20, the previously hard-coded value.
        output_path: Destination PNG path; a missing parent directory is
            created.
    """
    import os

    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    generator.load_weights('generator_weights')

    if pretty:
        discriminator = discriminator_model()
        discriminator.compile(loss='binary_crossentropy', optimizer="SGD")
        discriminator.load_weights('discriminator_weights')
        noise = _sample_noise(batch_size * pretty_factor, noise_dim)
        generated_images = generator.predict(noise, verbose=1)
        d_pret = discriminator.predict(generated_images, verbose=1)
        pretty_images = _select_highest_scored(generated_images, d_pret,
                                               batch_size)
        image = combine_images(pretty_images)
    else:
        noise = _sample_noise(batch_size, noise_dim)
        generated_images = generator.predict(noise, verbose=1)
        image = combine_images(generated_images)

    # Map [-1, 1] to [0, 255]. Clip so values slightly outside that range
    # saturate instead of wrapping around in the uint8 cast.
    image = np.clip(image * 127.5 + 127.5, 0, 255)

    out_dir = os.path.dirname(output_path)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    Image.fromarray(image.astype(np.uint8)).save(output_path)


def _sample_noise(n, noise_dim):
    """Return an (n, noise_dim) array of values drawn uniformly from [-1, 1]."""
    # One call fills the matrix row by row, in the same order as drawing
    # each row separately from the global RNG.
    return np.random.uniform(-1, 1, (n, noise_dim))


def _select_highest_scored(images, scores, k):
    """Return the ``k`` images with the highest score, best first, as float32.

    ``scores`` may be 1-D, or 2-D with the score in column 0. The output keeps
    only index 0 of axis 1 (with that axis retained), giving shape
    ``(k, 1) + images.shape[2:]``.
    NOTE(review): this presumes channel-first, single-channel images
    (N, 1, H, W) — confirm against generator_model.
    """
    scores = np.asarray(scores)
    scores = scores.reshape(len(scores), -1)[:, 0]
    # A stable sort on negated scores gives descending order. Ties keep their
    # original order, matching list.sort(reverse=True).
    order = np.argsort(-scores, kind='stable')[:k]
    return images[order, :1].astype(np.float32)
评论列表
文章目录