keras_auto_encoder.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:dem 作者: hengyuan-hu 项目源码 文件源码
def relu_deep_model1(input_shape, relu_max):
    """Build a strided-convolution auto-encoder and return it as a Keras Model.

    The encoder stacks three stride-2 ReLU convolutions (each followed by
    batch normalization) and one stride-4 linear convolution with 1024
    filters. Gaussian noise is added to that output and the result is passed
    through ``Activation(relu_n(relu_max))`` to form the code.  The decoder is
    five transposed convolutions, each followed by batch normalization; the
    last one is linear with 3 output channels.

    Args:
        input_shape: shape (without the batch dimension) handed to ``Input``.
            The decoder's ``output_shape`` values are hard-coded to end at
            32x32x3 and batch normalization uses ``axis=3``, so this
            presumably has to be ``(32, 32, 3)`` channels-last -- TODO confirm
            against callers.
        relu_max: argument forwarded to ``relu_n`` to build the activation
            applied to the code.  ``relu_n`` is defined outside this block;
            presumably it caps the ReLU at this value -- TODO confirm.

    Returns:
        A ``Model`` mapping the input image tensor to the batch-normalized
        output of the last decoder layer.
    """
    input_img = Input(shape=input_shape)
    # A single pre-formatted argument prints the same text whether ``print``
    # is the Python 2 statement or the Python 3 function.
    print('input shape: %s' % (input_img._keras_shape,))

    # ---- encoder (spatial sizes below assume a 32x32 input) ----
    # input: 32 x 32
    x = Conv2D(64, 3, 3, activation='relu', border_mode='same', subsample=(2, 2))(input_img)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: 16 x 16 x 64
    x = Conv2D(128, 3, 3, activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: 8 x 8 x 128
    x = Conv2D(256, 3, 3, activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: 4 x 4 x 256; the 4x4 kernel with stride 4 collapses it to the
    # latent code of shape (1, 1, 1024).
    x = Conv2D(1024, 4, 4, activation='linear',
               border_mode='same', subsample=(4, 4))(x)
    # Noise is injected before the bounded activation, on the linear output.
    x = GaussianNoise(0.2)(x)
    encoded = Activation(relu_n(relu_max))(x)

    print('encoded shape: %s' % (encoded.get_shape().as_list(),))

    # ---- decoder ----
    # In the original design there is no BN as the first layer of the decoder
    # because of a bug, so the code is fed to the decoder directly.
    x = encoded

    # The batch size is only known at run time; take it from the tensor so the
    # explicit output_shape of every transposed convolution follows it.
    batch_size = tf.shape(x)[0]
    # input: (1, 1, 1024)
    x = Deconv2D(512, 4, 4, output_shape=[batch_size, 4, 4, 512],
                 activation='relu', border_mode='same', subsample=(4, 4))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: (4, 4, 512)
    x = Deconv2D(256, 5, 5, output_shape=[batch_size, 8, 8, 256],
                 activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: (8, 8, 256)
    x = Deconv2D(128, 5, 5, output_shape=(batch_size, 16, 16, 128),
                 activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: (16, 16, 128)
    x = Deconv2D(64, 5, 5, output_shape=(batch_size, 32, 32, 64),
                 activation='relu', border_mode='same', subsample=(2, 2))(x)
    x = BatchNormalization(mode=2, axis=3)(x)
    # input: (32, 32, 64); stride 1 keeps the resolution and maps to 3 channels.
    x = Deconv2D(3, 5, 5, output_shape=(batch_size, 32, 32, 3),
                 activation='linear', border_mode='same', subsample=(1, 1))(x)
    decoded = BatchNormalization(mode=2, axis=3)(x)
    print('decoded shape: %s' % (decoded.get_shape().as_list(),))

    autoencoder = Model(input_img, decoded)
    return autoencoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号