def deep_decoder1(input_shape):
    """Build a Deconv2D decoder that upsamples its input 4x spatially.

    Architecture, in order: each ``Deconv2D`` is preceded by a
    ``BatchNormalization(mode=2, axis=3)``:

    * two stride-1 layers at the input resolution (32 filters, 1x1 then 3x3),
    * a stride-2 upsampling layer plus a stride-1 layer (64 filters, 3x3),
    * a stride-2 upsampling layer plus a stride-1 layer (32 filters, 3x3),
    * a stride-1 linear projection to 3 channels,

    followed by one final ``BatchNormalization``.

    Args:
        input_shape: per-sample shape (no batch axis) given to ``Input``.
            The code unpacks ``_keras_shape[1:]`` as ``(h, w, channels)`` and
            normalizes on ``axis=3``, so it expects a channels-last rank-3
            shape whose ``h`` and ``w`` are known integers.

    Returns:
        A ``Model`` from the encoded input to a tensor of spatial size
        ``(4 * h, 4 * w)`` with 3 channels.
    """
    encoded = Input(shape=input_shape)
    # Same text the Python 2 ``print`` statement produced, but valid on
    # Python 2 and Python 3 alike.
    print('decoder input shape: %s' % (encoded._keras_shape,))
    # Deconv2D needs a full output_shape including the batch dimension; take
    # it dynamically from the input so the model is not tied to one batch size.
    batch_size = tf.shape(encoded)[0]
    h, w, _ = encoded._keras_shape[1:]

    # (filters, kernel size, stride, activation) for each Deconv2D, in order.
    deconv_specs = [
        (32, 1, 1, 'relu'),
        (32, 3, 1, 'relu'),
        (64, 3, 2, 'relu'),
        (64, 3, 1, 'relu'),
        (32, 3, 2, 'relu'),
        (32, 3, 1, 'relu'),
        (3, 3, 1, 'linear'),
    ]

    x = encoded
    for filters, kernel, stride, activation in deconv_specs:
        x = BatchNormalization(mode=2, axis=3)(x)
        # With border_mode='same' the spatial size scales exactly by the
        # stride; deriving it here (instead of hard-coding it) keeps every
        # declared output_shape consistent for any input h, w.
        h *= stride
        w *= stride
        x = Deconv2D(filters, kernel, kernel,
                     output_shape=(batch_size, h, w, filters),
                     activation=activation, border_mode='same',
                     subsample=(stride, stride))(x)

    # NOTE(review): the linear 3-channel output is itself batch-normalized;
    # confirm this is intended for a reconstruction target.
    decoded = BatchNormalization(mode=2, axis=3)(x)
    return Model(encoded, decoded)
评论列表
文章目录