dcgan.py source code

python

Project: dcgan  Author: kyloon
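import numpy as np
from PIL import Image

# generator_model, discriminator_model and combine_images are assumed to be
# defined elsewhere in dcgan.py; they are not shown in this excerpt.


def generate(batch_size, pretty=False):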
    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    generator.load_weights('generator_weights')
    if pretty:
        discriminator = discriminator_model()
        discriminator.compile(loss='binary_crossentropy', optimizer="SGD")
        discriminator.load_weights('discriminator_weights')
        # Oversample 20x so the discriminator has candidates to rank.
        noise = np.random.uniform(-1, 1, (batch_size*20, 100))
        generated_images = generator.predict(noise, verbose=1)
        d_pret = discriminator.predict(generated_images, verbose=1)
        # Pair each discriminator score with its image index,
        # then sort by score, highest first.
        index = np.arange(0, batch_size*20)
        index.resize((batch_size*20, 1))
        pre_with_index = list(np.append(d_pret, index, axis=1))
        pre_with_index.sort(key=lambda x: x[0], reverse=True)
        pretty_images = np.zeros((batch_size, 1) +
                               (generated_images.shape[2:]), dtype=np.float32)
        for i in range(int(batch_size)):
            idx = int(pre_with_index[i][1])
            pretty_images[i, 0, :, :] = generated_images[idx, 0, :, :]
        image = combine_images(pretty_images)
    else:
        noise = np.random.uniform(-1, 1, (batch_size, 100))
        generated_images = generator.predict(noise, verbose=1)
        image = combine_images(generated_images)
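    # Rescale from [-1, 1] to [0, 255]; this assumes the generator output
    # lies in [-1, 1] (e.g. a tanh output layer).
    image = image*127.5+127.5
    # The "images" directory must already exist.
    Image.fromarray(image.astype(np.uint8)).save(
        "images/generated_image.png")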
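The `pretty` branch above ranks generated images by discriminator score and keeps the best `batch_size` of them. The following is a minimal, self-contained sketch of that selection step on synthetic data, using `np.argsort` instead of the list sort. The names `select_top_k`, `tile_images` and `to_uint8` are hypothetical and do not come from the project; in particular, `tile_images` is only an assumed stand-in for `combine_images`, whose real implementation is not shown in this excerpt. Images are assumed to be channels-first with shape `(N, 1, H, W)`.

```python
import math

import numpy as np


def select_top_k(images, scores, k):
    """Return the k images with the highest scores, best first."""
    scores = np.asarray(scores).reshape(-1)
    # argsort is ascending, so reverse it and keep the first k indices.
    top = np.argsort(scores)[::-1][:k]
    return images[top]


def tile_images(images):
    """Hypothetical stand-in for combine_images: (N, 1, H, W) -> 2D grid."""
    num = images.shape[0]
    cols = int(math.sqrt(num))
    rows = int(math.ceil(num / cols))
    h, w = images.shape[2:]
    grid = np.zeros((rows * h, cols * w), dtype=images.dtype)
    for idx, img in enumerate(images):
        r, c = divmod(idx, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img[0]
    return grid


def to_uint8(image):
    """Map values from [-1, 1] to [0, 255], clipping anything outside."""
    return np.clip(image * 127.5 + 127.5, 0, 255).astype(np.uint8)


# Synthetic data standing in for generator and discriminator output.
rng = np.random.default_rng(0)
fake_images = rng.uniform(-1, 1, size=(8, 1, 4, 4)).astype(np.float32)
fake_scores = np.array([0.1, 0.9, 0.3, 0.8, 0.2, 0.7, 0.4, 0.6])

best = select_top_k(fake_images, fake_scores, k=4)
grid = to_uint8(tile_images(best))
```

Compared with the list-based sort, `np.argsort` avoids packing scores and indices into one float array, but it may order tied scores differently.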