models.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:enhance 作者: cdiazbas 项目源码 文件源码
def keepsize_256(nx, ny, noise, depth, activation='relu', n_filters=64, l2_reg=1e-7):
    """
    Build a deep residual convolutional Keras model with an upsampling tail.

    The network takes a single-channel input of shape (nx, ny, 1), applies
    Gaussian noise, a stem convolution, a stack of residual blocks closed by
    a long skip connection back to the stem, then an UpSampling2D stage and a
    1x1 convolution producing a single output channel.

    Parameters
    ----------
    nx, ny : spatial dimensions passed to the Input layer.
    noise : argument forwarded to the GaussianNoise layer.
    depth : number of residual blocks; one block is always built, and
        ``depth - 1`` additional ones follow.
    activation : activation passed to every Activation layer.
    n_filters : number of filters of the 3x3 convolutions in the trunk; the
        convolution after the upsampling uses ``4 * n_filters`` (256 with
        the default of 64).
    l2_reg : factor of the L2 kernel regularizer attached to every
        convolution.

    Returns
    -------
    A Model mapping the input tensor to the final single-channel tensor.

    NOTE(review): every 3x3 convolution uses padding='valid' after a
    ReflectionPadding2D layer; presumably that layer pads one pixel per side
    so the spatial size is preserved -- confirm against its definition.
    """

    def padded_conv3x3(tensor, filters):
        # Reflection padding followed by a 3x3 'valid' convolution. A fresh
        # regularizer object is created for each convolution.
        padded = ReflectionPadding2D()(tensor)
        return Conv2D(
            filters,
            (3, 3),
            padding='valid',
            kernel_initializer='he_normal',
            kernel_regularizer=l2(l2_reg)
        )(padded)

    def residual(block_input, filters):
        # conv -> BN -> activation -> conv -> BN, summed with the block input
        branch = padded_conv3x3(block_input, filters)
        branch = BatchNormalization()(branch)
        branch = Activation(activation)(branch)
        branch = padded_conv3x3(branch, filters)
        branch = BatchNormalization()(branch)
        return add([branch, block_input])

    net_input = Input(shape=(nx, ny, 1))
    noisy = GaussianNoise(noise)(net_input)

    # Stem: its activated output is reused later by the long skip connection
    stem = padded_conv3x3(noisy, n_filters)
    stem = Activation(activation)(stem)

    # The first residual block is unconditional; the loop adds the rest
    trunk = residual(stem, n_filters)
    for _ in range(depth - 1):
        trunk = residual(trunk, n_filters)

    # Closing convolution of the trunk plus the long skip from the stem
    trunk = padded_conv3x3(trunk, n_filters)
    trunk = BatchNormalization()(trunk)
    trunk = add([trunk, stem])

    # Upsampling for superresolution
    upsampled = UpSampling2D()(trunk)
    upsampled = padded_conv3x3(upsampled, 4 * n_filters)
    upsampled = Activation(activation)(upsampled)

    # 1x1 convolution collapsing the feature maps into a single channel
    final = Conv2D(
        1,
        (1, 1),
        padding='same',
        kernel_initializer='he_normal',
        kernel_regularizer=l2(l2_reg)
    )(upsampled)

    return Model(inputs=net_input, outputs=final)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号