wgan_cnn.py source code

python
Project: nn_playground  Author: DingKe
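from keras import backend as K
from keras.layers import Conv2D, Dense, Input, Reshape, UpSampling2D
from keras.models import Model, Sequential

# Assumption: channels_first layout (C, H, W), as implied by
# Reshape((128, 7, 7)) and the shape comments below.
K.set_image_data_format('channels_first')


def build_generator(latent_size):
    # maps a latent vector of length latent_size to a (1, 28, 28) image in [-1, 1]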
    cnn = Sequential()
    cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))
    cnn.add(Dense(128 * 7 * 7, activation='relu'))
    cnn.add(Reshape((128, 7, 7)))

    # upsample to (..., 14, 14)
    cnn.add(UpSampling2D(size=(2, 2)))
    cnn.add(Conv2D(256, 5, padding='same',
                   activation='relu', kernel_initializer='glorot_normal'))

    # upsample to (..., 28, 28)
    cnn.add(UpSampling2D(size=(2, 2)))
    cnn.add(Conv2D(128, 5, padding='same',
                   activation='relu', kernel_initializer='glorot_normal'))

    # reduce to a single output channel
    cnn.add(Conv2D(1, 2, padding='same',
                   activation='tanh', kernel_initializer='glorot_normal'))

    # this is the z space commonly referred to in GAN papers
    latent = Input(shape=(latent_size,))

    fake_image = cnn(latent)

    return Model(inputs=latent, outputs=fake_image)
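
The shape comments in build_generator can be checked by hand. The sketch below is plain Python, not Keras: it traces the per-sample shape through the convolutional stages, assuming channels_first layout, stride-1 convolutions with padding='same', and 2x2 upsampling. The helper names (upsample2d, conv2d_same, trace_generator_shapes) are hypothetical and only mirror the layer arithmetic; they are not part of the original file.

```python
def upsample2d(shape, size=(2, 2)):
    """Shape after UpSampling2D on a channels_first (C, H, W) tensor."""
    c, h, w = shape
    return (c, h * size[0], w * size[1])


def conv2d_same(shape, filters):
    """Shape after a stride-1 Conv2D with padding='same': only C changes."""
    _, h, w = shape
    return (filters, h, w)


def trace_generator_shapes():
    """Per-sample shapes after each stage following the Dense layers."""
    shapes = [(128, 7, 7)]                          # Reshape((128, 7, 7))
    shapes.append(upsample2d(shapes[-1]))           # UpSampling2D(size=(2, 2))
    shapes.append(conv2d_same(shapes[-1], 256))     # Conv2D(256, 5, padding='same')
    shapes.append(upsample2d(shapes[-1]))           # UpSampling2D(size=(2, 2))
    shapes.append(conv2d_same(shapes[-1], 128))     # Conv2D(128, 5, padding='same')
    shapes.append(conv2d_same(shapes[-1], 1))       # Conv2D(1, 2, padding='same')
    return shapes


if __name__ == "__main__":
    for shape in trace_generator_shapes():
        print(shape)
```

Under these assumptions the last entry is (1, 28, 28), a single-channel 28x28 image. With the channels_last layout, Reshape((128, 7, 7)) would be read as height 128, width 7, and 7 channels, and the upsampling would act on the wrong axes.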