def deep_decoder2(input_shape):
    """Build a transposed-convolution decoder that upsamples a code to 32x32x3.

    The network is five ``Deconv2D`` + ``BatchNormalization`` stages in
    channels-last layout (BatchNormalization normalizes ``axis=3``):

        (4, 4, 512) -> (8, 8, 256) -> (16, 16, 128) -> (32, 32, 64)
        -> (32, 32, 3)

    Args:
        input_shape: Shape of one encoded sample, without the batch dimension.
            The first layer hard-codes a 4x4 output with stride 4, so the
            input presumably has spatial size 1x1 (e.g. ``(1, 1, 512)``).
            TODO confirm against callers.

    Returns:
        A Keras ``Model`` mapping the encoded input to the decoded tensor.
    """
    encoded = Input(shape=input_shape)
    # Single pre-formatted argument so the output is identical on Python 2
    # and Python 3 (a multi-argument print() would print a tuple on Py2).
    print('encoded shape: %s' % (encoded.get_shape().as_list(),))

    # Deconv2D requires an explicit output_shape that includes the batch
    # dimension. Read it dynamically so the model accepts any batch size.
    batch_size = tf.shape(encoded)[0]

    # (filters, kernel size, output height/width, stride, activation)
    layer_specs = (
        (512, 4, 4, 4, 'relu'),    # -> (4, 4, 512)
        (256, 5, 8, 2, 'relu'),    # -> (8, 8, 256)
        (128, 5, 16, 2, 'relu'),   # -> (16, 16, 128)
        (64, 5, 32, 2, 'relu'),    # -> (32, 32, 64)
        (3, 5, 32, 1, 'linear'),   # -> (32, 32, 3)
    )

    x = encoded
    for filters, kernel, size, stride, activation in layer_specs:
        x = Deconv2D(filters, kernel, kernel,
                     output_shape=(batch_size, size, size, filters),
                     activation=activation, border_mode='same',
                     subsample=(stride, stride))(x)
        # NOTE(review): this also batch-normalizes the final linear 3-channel
        # output, as the original did; confirm that this is intended.
        x = BatchNormalization(mode=2, axis=3)(x)

    return Model(encoded, x)
评论列表
文章目录