models.py source code

python

Project: cyclegan_keras    Author: shadySource
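from keras.layers import Input, Conv2D, LeakyReLU, GlobalAveragePooling2D, Activation
from keras.models import Model
from keras_contrib.layers import InstanceNormalization

# residual_block(x, scale, num_id) is assumed to be defined elsewhere in this
# models.py; it is not part of this snippet.

def mnist_discriminator(input_shape=(28, 28, 1), scale=1/4):
    """CycleGAN discriminator for 28x28 MNIST-style images.

    `scale` multiplies every filter count. It relies on Python 3 true division
    (under Python 2, 1/4 == 0 and the layers would get 0 filters).
    Returns a Model mapping an image to a single real/fake probability.
    """
    x0 = Input(input_shape)
    # Strided conv block: 28x28 -> 14x14
    x = Conv2D(int(128*scale), (3, 3), strides=(2, 2), padding='same')(x0)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)
    # Strided conv block: 14x14 -> 7x7
    x = Conv2D(int(64*scale), (3, 3), strides=(2, 2), padding='same')(x)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)
    x = residual_block(x, scale, num_id=2)
    x = residual_block(x, scale*2, num_id=3)
    # Two more strided conv blocks shrink the feature map further
    x = Conv2D(int(128*scale), (3, 3), strides=(2, 2), padding='same')(x)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv2D(int(128*scale), (3, 3), strides=(2, 2), padding='same')(x)
    x = InstanceNormalization()(x)
    x = LeakyReLU()(x)
    # Single-filter conv produces a one-channel score map
    x = Conv2D(1, (3, 3), strides=(2, 2), padding='same')(x)
    x = GlobalAveragePooling2D()(x)  # average the score map to one value per image
    x = Activation('sigmoid')(x)     # probability that the input is real
    return Model(x0, x)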
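As a sanity check on the layer stack above: a `Conv2D` with `padding='same'` and stride 2 produces a spatial size of ceil(n / 2). The sketch below traces the spatial size and filter count after each strided convolution in plain Python, so it runs without Keras. It assumes, hypothetically, that `residual_block` preserves the spatial size, because its definition is not shown here; the helper names `same_conv_out` and `trace_discriminator_shapes` are illustrative and not part of the project.

```python
import math


def same_conv_out(size, stride):
    """Spatial size after a Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)


def trace_discriminator_shapes(input_hw=28, scale=1/4):
    """Return (spatial size, filters) after each strided Conv2D in mnist_discriminator.

    Hypothetical assumption: residual_block keeps the spatial size unchanged,
    so it is left out of the trace.
    """
    convs = [  # (filters, stride) of each strided Conv2D, in order
        (int(128 * scale), 2),
        (int(64 * scale), 2),
        (int(128 * scale), 2),
        (int(128 * scale), 2),
        (1, 2),
    ]
    hw = input_hw
    trace = []
    for filters, stride in convs:
        hw = same_conv_out(hw, stride)
        trace.append((hw, filters))
    return trace


if __name__ == "__main__":
    for size, filters in trace_discriminator_shapes():
        print(f"{size}x{size}x{filters}")
```

With the default arguments the trace ends at a 1x1x1 score map, so `GlobalAveragePooling2D` yields one value per image before the sigmoid; the filter counts also show how `scale` shrinks the network (32, 16, 32, 32, 1 filters at `scale=1/4`).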