archs.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:kaggle-dstl-satellite-imagery-feature-detection 作者: alno 项目源码 文件源码
def rnet1_mi(input_shapes, n_classes):
    """Build the multi-input residual U-Net pyramid model ``rnet1_mi``.

    Written against the Keras 1.x functional API (``merge``,
    ``Convolution2D(nb_filter, nb_row, nb_col, border_mode=...)``,
    ``Model(input=..., output=...)``).

    How the model is built:

    * The image input ``in_I`` is average-pooled into a pyramid at full,
      1/2, 1/4 and 1/8 resolution.
    * The second input ``in_M`` is concatenated onto the 1/4-resolution
      level. Its non-channel dimensions must therefore match that level for
      the concat to succeed.
    * The coarsest level is processed by a small U-Net.
    * Each finer level then upsamples the running output. It refines that
      output with one or two residual U-Net blocks, which also receive the
      level's input.

    Args:
        input_shapes: mapping with keys ``'in_I'`` and ``'in_M'``, giving the
            shape passed to ``Input`` for each input. Every concatenation and
            normalisation uses axis 1, which suggests channels-first
            ``(channels, rows, cols)`` shapes. NOTE(review): confirm.
            Presumably the spatial size must also be divisible by 16, since
            the deepest path pools 4 times. Verify against the callers.
        n_classes: number of output channels of the final 1x1 convolution.

    Returns:
        A ``Model`` mapping ``[in_I, in_M]`` to a sigmoid-activated tensor
        with ``n_classes`` channels at the resolution of ``in_I``.
    """

    def conv_bn_prelu(filters, tensor):
        # 3x3 same-padded convolution without a bias term. BatchNormalization
        # is applied directly afterwards, then PReLU shared over axes 2 and 3.
        tensor = Convolution2D(filters, 3, 3, border_mode='same', init='he_normal', bias=False)(tensor)
        tensor = BatchNormalization(axis=1, mode=0)(tensor)
        return PReLU(shared_axes=[2, 3])(tensor)

    def unet_block(sizes, inp):
        # Small U-Net. sizes[:-1] are the encoder/decoder widths and
        # sizes[-1] is the bottleneck width. The output has sizes[0]
        # channels whenever len(sizes) > 1.
        encoder_widths = sizes[:-1]
        skip_tensors = []
        tensor = inp

        # Encoder: keep each pre-pooling activation as a skip connection.
        for width in encoder_widths:
            tensor = conv_bn_prelu(width, tensor)
            skip_tensors.append(tensor)
            tensor = MaxPooling2D((2, 2))(tensor)

        # Bottleneck at the coarsest resolution of this block.
        tensor = conv_bn_prelu(sizes[-1], tensor)

        # Decoder: revisit the encoder stages deepest-first. Upsample, then
        # concatenate the matching skip before each conv.
        for width, skip in zip(reversed(encoder_widths), reversed(skip_tensors)):
            upsampled = UpSampling2D((2, 2))(tensor)
            tensor = conv_bn_prelu(width, merge([upsampled, skip], mode='concat', concat_axis=1))

        return tensor

    def residual_refine(prev_out, level_inp, sizes):
        # Feed a U-Net the level input stacked with the current output.
        # Then add the U-Net's result back onto the current output (residual
        # update). Every block here ends at 32 channels, so the sum matches
        # the running output's width.
        stacked = merge([level_inp, prev_out], mode='concat', concat_axis=1)
        return merge([unet_block(sizes, stacked), prev_out], mode='sum')

    in_I = Input(input_shapes['in_I'], name='in_I')
    in_M = Input(input_shapes['in_M'], name='in_M')

    # Input pyramid: full, 1/2, 1/4 (with in_M appended) and 1/8 resolution.
    level0 = in_I
    level1 = AveragePooling2D((2, 2))(level0)
    level2 = merge([AveragePooling2D((2, 2))(level1), in_M], mode='concat', concat_axis=1)
    level3 = AveragePooling2D((2, 2))(level2)

    # Coarse-to-fine schedule. Each entry is a pyramid level plus the U-Net
    # sizes for the residual refinements applied, in order, after
    # upsampling to that level.
    refinements = [
        (level2, [[32, 48]]),
        (level1, [[32, 48], [32, 48, 64]]),
        (level0, [[32, 48], [32, 48, 64]]),
    ]

    out = unet_block([32, 48], level3)
    for level_inp, block_sizes in refinements:
        out = UpSampling2D((2, 2))(out)
        for sizes in block_sizes:
            out = residual_refine(out, level_inp, sizes)

    # Final 1x1 convolution producing per-class sigmoid scores.
    scores = Convolution2D(n_classes, 1, 1, activation='sigmoid')(out)

    return Model(input=[in_I, in_M], output=scores)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号