mnist2_ssgan_trainer.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ssgan 作者: samrussell 项目源码 文件源码
def build_models(self, input_shape):
    middle_neurons = 10

    self.encoder = Sequential()
    self.encoder.add(Conv2D(64, (5, 5), strides=(2, 2), padding = 'same', input_shape=input_shape))
    self.encoder.add(Activation(selu))
    self.encoder.add(Conv2D(128, (5, 5), strides=(2, 2), padding = 'same'))
    self.encoder.add(Activation(selu))
    self.encoder.add(Flatten())
    self.encoder.add(Dense(middle_neurons))
    self.encoder.add(Activation('sigmoid'))
    self.encoder.summary()

    self.decoder = Sequential()
    self.decoder.add(Dense(7*7*128, input_shape=(middle_neurons,)))
    self.decoder.add(Activation(selu))
    if keras.backend.image_data_format() == 'channels_first':
        self.decoder.add(Reshape([128, 7, 7]))
    else:    
        self.decoder.add(Reshape([7, 7, 128]))
    self.decoder.add(UpSampling2D(size=(2, 2)))
    self.decoder.add(Conv2D(64, (5, 5), padding='same'))
    self.decoder.add(Activation(selu))
    self.decoder.add(UpSampling2D(size=(2, 2)))
    self.decoder.add(Conv2D(1, (5, 5), padding='same'))
    self.decoder.add(Activation('sigmoid'))
    self.decoder.summary()

    self.autoencoder = Sequential()
    self.autoencoder.add(self.encoder)
    self.autoencoder.add(self.decoder)
    self.autoencoder.compile(loss='mean_squared_error',
                                  optimizer=Adam(lr=1e-4),
                                  metrics=['accuracy'])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号