decoder.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:enet-keras 作者: PavlosMelissinos 项目源码 文件源码
def bottleneck(encoder, output, upsample=False, reverse_module=False):
    """Build one bottleneck block on top of ``encoder``.

    The main branch is a 1x1 convolution, followed by either a 3x3
    convolution or (when upsampling) a stride-2 3x3 transposed convolution,
    followed by a 1x1 convolution producing ``output`` filters. The first
    two convolutions are each followed by batch normalisation and PReLU.

    The shortcut branch is ``encoder`` itself, or a 1x1 convolution plus
    batch normalisation of it when its last dimension differs from
    ``output`` or when ``upsample`` is set.

    Args:
        encoder: Input tensor; must provide ``get_shape()`` and be accepted
            by the Keras layers used here.
        output: Filter count of the last main-branch convolution and of the
            shortcut projection. The inner convolutions use ``output // 4``.
        upsample: When true, the middle layer is a ``Conv2DTranspose`` with
            strides ``(2, 2)`` instead of a plain ``Conv2D``.
        reverse_module: Second input passed to ``MaxUnpooling2D`` together
            with the projected shortcut when upsampling (presumably the
            pooling indices of the matching encoder stage -- confirm against
            the caller). The default ``False`` means "no unpooling".

    Returns:
        The PReLU-activated sum of the batch-normalised main branch and the
        shortcut. When ``upsample`` is true and ``reverse_module`` is
        ``False``, the raw main-branch output is returned instead, without
        the residual sum, final normalisation or activation.
    """

    def _normalise_and_activate(tensor):
        # Batch norm followed by PReLU, as used after the first two convs.
        tensor = BatchNormalization(momentum=0.1)(tensor)
        return PReLU(shared_axes=[1, 2])(tensor)

    reduced = output // 4
    unpool_available = reverse_module is not False

    # --- main branch: 1x1 -> 3x3 (or transposed 3x3) -> 1x1 ---
    main = Conv2D(reduced, (1, 1), use_bias=False)(encoder)
    main = _normalise_and_activate(main)

    if upsample:
        middle = Conv2DTranspose(filters=reduced, kernel_size=(3, 3), strides=(2, 2), padding='same')
    else:
        middle = Conv2D(reduced, (3, 3), padding='same', use_bias=True)
    main = middle(main)
    main = _normalise_and_activate(main)

    main = Conv2D(output, (1, 1), padding='same', use_bias=False)(main)

    # --- shortcut branch ---
    # NOTE: the projection layers are instantiated even in the case where
    # the shortcut is discarded by the early return below; this keeps the
    # set and order of created layers the same in every configuration.
    shortcut = encoder
    needs_projection = encoder.get_shape()[-1] != output or upsample
    if needs_projection:
        shortcut = Conv2D(output, (1, 1), padding='same', use_bias=False)(shortcut)
        shortcut = BatchNormalization(momentum=0.1)(shortcut)
        if upsample and unpool_available:
            shortcut = MaxUnpooling2D()([shortcut, reverse_module])

    # Upsampling without an unpooling source: hand back the main branch as-is.
    if upsample and not unpool_available:
        return main

    # --- merge: normalise main branch, add shortcut, activate ---
    main = BatchNormalization(momentum=0.1)(main)
    merged = add([main, shortcut])
    return PReLU(shared_axes=[1, 2])(merged)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号