import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Lambda
from keras.layers.convolutional import Conv2D, Deconv2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model


def deep_model1(input_shape):
    # Keras 1.x API (border_mode / subsample / output_shape), channels-last
    # input such as (32, 32, 3).
    input_img = Input(shape=input_shape)
    print 'input shape:', input_img._keras_shape
    # 32x32 -> 16x16
    x = Conv2D(32, 3, 3, activation='relu', border_mode='same', subsample=(2, 2))(input_img)
    x = BatchNormalization(mode=2, axis=3)(x)
    # 16x16 -> 8x8
    x = Conv2D(64, 3, 3, activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # 8x8 -> 4x4
    x = Conv2D(128, 3, 3, activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # 4x4 -> 1x1 latent maps
    latent_dim = (1, 1, 1024)
    z_mean = Conv2D(1024, 4, 4, activation='linear',
                    border_mode='same', subsample=(4, 4))(x)
    # z_mean = GaussianNoise(0.1)(z_mean)
    # TODO: the next layer uses 16K parameters; will it be a problem?
    z_log_var = Conv2D(1024, 4, 4, activation='linear',
                       border_mode='same', subsample=(4, 4))(x)
    # Reparameterization trick: sample z ~ N(z_mean, exp(z_log_var));
    # sampling_gaussian is defined elsewhere in the post (see the sketch below).
    z = Lambda(sampling_gaussian, output_shape=latent_dim)([z_mean, z_log_var])
    print 'encoded shape:', z._keras_shape
    # x = BatchNormalization(mode=2, axis=3)(z)
    batch_size = tf.shape(z)[0]
    h, w, _ = z._keras_shape[1:]
    # z dim: (1, 1, 1024)
    x = Deconv2D(512, 4, 4, output_shape=[batch_size, 4, 4, 512],
                 activation='relu', border_mode='same', subsample=(4, 4))(z)
    x = BatchNormalization(mode=2, axis=3)(x)
    # dim: (4, 4, 512)
    x = Deconv2D(256, 5, 5, output_shape=[batch_size, 8, 8, 256],
                 activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # dim: (8, 8, 256)
    x = Deconv2D(128, 5, 5, output_shape=[batch_size, 16, 16, 128],
                 activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # dim: (16, 16, 128)
    x = Deconv2D(64, 5, 5, output_shape=[batch_size, 32, 32, 64],
                 activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # dim: (32, 32, 64)
    x = Deconv2D(3, 5, 5, output_shape=[batch_size, 32, 32, 3],
                 activation='linear', border_mode='same', subsample=(1, 1))(x)
    decoded = BatchNormalization(mode=2, axis=3)(x)
    print 'decoded shape:', decoded._keras_shape
    autoencoder = Model(input_img, decoded)

    # Define the VAE loss: pixel-wise reconstruction error plus the KL
    # divergence between q(z|x) and the standard normal prior.
    def vae_loss(y, y_pred):
        # TODO: generalize this function
        recon_loss = K.sum(K.square(y_pred - y), axis=[1, 2, 3])
        kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
                               axis=[1, 2, 3])
        print ('pre average loss shape:',
               recon_loss.get_shape().as_list(),
               kl_loss.get_shape().as_list())
        return K.mean(recon_loss + kl_loss)

    return autoencoder, vae_loss
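
# The encoder above calls sampling_gaussian, which the post defines elsewhere.
# For reference, a minimal sketch of such a reparameterization helper (an
# assumption, not the post's exact code), treating z_log_var as the log of the
# variance just as the KL term in vae_loss does:
def sampling_gaussian(args):
    z_mean, z_log_var = args
    # eps ~ N(0, 1) with the same shape as z_mean, then z = mean + sigma * eps
    eps = K.random_normal(shape=K.shape(z_mean), mean=0., std=1.)
    return z_mean + K.exp(0.5 * z_log_var) * eps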
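
# A hypothetical usage sketch (not from the post): build the model for 32x32
# RGB inputs, compile it with the returned loss, and fit with the input as its
# own reconstruction target. Keras 1.x fit() takes nb_epoch.
if __name__ == '__main__':
    import numpy as np
    x_train = np.random.rand(256, 32, 32, 3).astype('float32')  # stand-in data
    autoencoder, vae_loss = deep_model1(input_shape=(32, 32, 3))
    autoencoder.compile(optimizer='adam', loss=vae_loss)
    autoencoder.fit(x_train, x_train, nb_epoch=1, batch_size=32)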