def mnist_discriminator(input_shape=(28, 28, 1), scale=1/4):
    """Build a convolutional discriminator for MNIST-sized images.

    The network stacks 3x3 stride-2 ``Conv2D -> InstanceNormalization ->
    LeakyReLU`` stages with two ``residual_block`` calls between them. It ends
    with a single-filter stride-2 convolution, global average pooling and a
    sigmoid.

    Args:
        input_shape: Shape of one input image, excluding the batch axis.
            Defaults to ``(28, 28, 1)``.
        scale: Width multiplier. Each downsampling stage uses
            ``int(base * scale)`` filters, where base is 128 or 64. It is also
            forwarded to ``residual_block`` as ``scale`` and ``scale * 2``.

    Returns:
        A Keras ``Model`` from the input image batch to one sigmoid score
        per image (a single channel after global average pooling).
    """

    def _down(tensor, base_filters):
        # One downsampling stage: 3x3 stride-2 conv, instance norm, LeakyReLU.
        tensor = Conv2D(int(base_filters * scale), (3, 3), strides=(2, 2),
                        padding='same')(tensor)
        tensor = InstanceNormalization()(tensor)
        return LeakyReLU()(tensor)

    x0 = Input(input_shape)
    x = _down(x0, 128)
    x = _down(x, 64)
    # NOTE(review): residual_block is defined elsewhere; its effect on the
    # spatial size / channel count is not visible here — confirm.
    x = residual_block(x, scale, num_id=2)
    x = residual_block(x, scale*2, num_id=3)
    x = _down(x, 128)
    x = _down(x, 128)
    # Reduce to a single feature map before pooling.
    x = Conv2D(1, (3, 3), strides=(2, 2), padding='same')(x)
    # Global average pooling (not Flatten): averages the single map to one
    # value per image.
    x = GlobalAveragePooling2D()(x)
    x = Activation('sigmoid')(x)
    return Model(x0, x)
# 评论列表 (scraped web-page residue: "comment list")
# 文章目录 (scraped web-page residue: "table of contents")