cgan.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: shenlan 作者: vector-1127 项目源码 文件源码
def _top_scoring_images(images, scores, count):
    """Return the ``count`` images with the highest discriminator scores.

    Args:
        images: array of generated images; the first axis indexes images.
        scores: discriminator outputs with one entry (or one row) per image.
            Only the first value of each row is used as the score.
        count: maximum number of images to keep.

    Returns:
        A float32 array of the selected images, ordered by descending score.
        It has the same per-image shape as ``images``. Images with equal
        scores keep their original relative order. If fewer than ``count``
        images are available, all of them are returned.
    """
    # Accept both (N,) and (N, k) score arrays; column 0 is the score.
    flat_scores = np.asarray(scores).reshape(len(images), -1)[:, 0]
    # Negating and using a stable sort gives descending order with ties kept
    # in input order (same as list.sort(key=..., reverse=True)).
    order = np.argsort(-flat_scores, kind='stable')[:int(count)]
    return np.asarray(images)[order].astype(np.float32, copy=False)


def generate(BATCH_SIZE, nice=False):
    """Generate images with the trained generator and write them to a PNG.

    Loads the 'test' split via ``get_data`` and rescales its inputs from
    [0, 255] to [-1, 1]. It then runs them through a generator restored from
    the ``'generator'`` weights file, tiles the outputs with
    ``combine_images``, maps them back to [0, 255] and writes the result to
    ``generated.png`` with OpenCV.

    Args:
        BATCH_SIZE: number of images to keep when ``nice`` is True.
        nice: if True, score every generated image with a discriminator
            restored from the ``'discriminator'`` weights file. Only the
            ``BATCH_SIZE`` highest-scoring images are kept.
    """
    # The second element returned by get_data is not needed for generation.
    X_train, _ = get_data('test')
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    generator.load_weights('generator')
    if nice:
        discriminator = discriminator_model()
        discriminator.compile(loss='binary_crossentropy', optimizer="SGD")
        discriminator.load_weights('discriminator')

        generated_images = generator.predict(X_train, verbose=1)
        d_pret = discriminator.predict(generated_images, verbose=1)
        # Rank by discriminator score over however many images were actually
        # generated (previously hard-coded to BATCH_SIZE*20). Keep the same
        # per-image layout that the non-nice branch hands to combine_images.
        nice_images = _top_scoring_images(generated_images, d_pret, BATCH_SIZE)
        image = combine_images(nice_images)
    else:
        generated_images = generator.predict(X_train)
        image = combine_images(generated_images)
    # Undo the [-1, 1] scaling applied to the inputs above.
    image = image * 127.5 + 127.5
    # NOTE(review): presumably moves a channels-first layout to channels-last
    # for cv2. swapaxes(0, 2) also exchanges the two remaining axes and
    # requires a 3-D array from combine_images — confirm.
    image = np.swapaxes(image, 0, 2)
    cv2.imwrite('generated.png', image)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号