def generate(BATCH_SIZE, nice=False):
    """Run a trained generator on the 'test' split and save a tiled image.

    Loads the 'test' inputs via ``get_data``, maps them from [0, 255] to
    [-1, 1], predicts with a generator whose weights are loaded from the file
    ``'generator'``, tiles the predictions with ``combine_images``, maps the
    result back to [0, 255] and writes it to ``generated.png``.

    Args:
        BATCH_SIZE: When ``nice`` is true, the number of top-scoring images
            to keep (fewer if fewer images were generated).
        nice: If true, score every generated image with a discriminator whose
            weights are loaded from the file ``'discriminator'``, and tile only
            channel 0 of the ``BATCH_SIZE`` highest-scoring images. If false,
            tile all generated images.
    """
    X_train, _ = get_data('test')
    # Map pixel values from [0, 255] to [-1, 1].
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    generator.load_weights('generator')

    if nice:
        discriminator = discriminator_model()
        discriminator.compile(loss='binary_crossentropy', optimizer="SGD")
        discriminator.load_weights('discriminator')
        generated_images = generator.predict(X_train, verbose=1)
        d_pret = discriminator.predict(generated_images, verbose=1)
        # The first output column of each row is that image's score. Rank by
        # score, descending. The stable sort keeps tied images in input order,
        # and the ranking size follows the real number of predictions instead
        # of a hard-coded BATCH_SIZE*20.
        scores = np.asarray(d_pret).reshape(len(d_pret), -1)[:, 0]
        best = np.argsort(-scores, kind='stable')[:int(BATCH_SIZE)]
        # Keep only channel 0, shaped (k, 1, ...).
        # NOTE(review): assumes channels-first (N, C, H, W) output - confirm.
        nice_images = generated_images[best, :1].astype(np.float32)
        image = combine_images(nice_images)
    else:
        generated_images = generator.predict(X_train)
        image = combine_images(generated_images)

    # Map back from [-1, 1] to [0, 255].
    image = image * 127.5 + 127.5
    # NOTE(review): if combine_images returns (C, H, W), swapaxes(0, 2) yields
    # (W, H, C), i.e. a transposed picture; transpose(1, 2, 0) would give
    # (H, W, C). Left unchanged pending confirmation of that layout.
    image = np.swapaxes(image, 0, 2)
    # Round and saturate to 8-bit explicitly (what cv2 would otherwise do
    # implicitly for PNG), which also yields a contiguous array.
    cv2.imwrite('generated.png', np.clip(np.rint(image), 0, 255).astype(np.uint8))
# (Scraped web-page residue removed: "Comment list", "Table of contents")