unet.py source code

python

Project: lsun_2017 Author: ternaus
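# Assumed Keras 2 functional-API imports; the import lines of the original file are not shown here.
from keras.models import Model
from keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D, Cropping2D,
                          BatchNormalization, Activation, concatenate)

# img_rows, img_cols, num_channels, num_mask_channels and the ConvBN2 helper
# are not part of this excerpt; they must be defined elsewhere in unet.py.


def get_unet0(num_start_filters=32):
    # Encoder: each level applies ConvBN2, then halves the spatial size with 2x2 max pooling.
    inputs = Input((img_rows, img_cols, num_channels))
    conv1 = ConvBN2(inputs, num_start_filters)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = ConvBN2(pool1, 2 * num_start_filters)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = ConvBN2(pool2, 4 * num_start_filters)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = ConvBN2(pool3, 8 * num_start_filters)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck at 1/16 of the input resolution.
    conv5 = ConvBN2(pool4, 16 * num_start_filters)

    # Decoder: upsample by 2 and concatenate with the encoder output of the same resolution.
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4])
    conv6 = ConvBN2(up6, 8 * num_start_filters)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3])
    conv7 = ConvBN2(up7, 4 * num_start_filters)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2])
    conv8 = ConvBN2(up8, 2 * num_start_filters)

    # Last decoder level is written out by hand so the crop can sit before the final BatchNormalization.
    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1])
    conv9 = Conv2D(num_start_filters, (3, 3), padding="same", kernel_initializer="he_uniform")(up9)
    conv9 = BatchNormalization()(conv9)
    conv9 = Activation('selu')(conv9)
    conv9 = Conv2D(num_start_filters, (3, 3), padding="same", kernel_initializer="he_uniform")(conv9)
    # Drop a 16-pixel border on every side: the output is 32 pixels smaller than the input in each dimension.
    crop9 = Cropping2D(cropping=((16, 16), (16, 16)))(conv9)
    conv9 = BatchNormalization()(crop9)
    conv9 = Activation('selu')(conv9)

    # 1x1 convolution maps the features to num_mask_channels outputs (no activation is applied here).
    conv10 = Conv2D(num_mask_channels, (1, 1))(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    return model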
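
The shape bookkeeping in get_unet0 can be checked without Keras. The sketch below is not part of the lsun_2017 project: unet0_shapes is a hypothetical helper that traces (rows, cols, channels) through the same sequence of layers. It assumes that ConvBN2 (not shown in this excerpt) preserves the spatial size and outputs exactly the requested number of filters, as the hand-written "same"-padded convolutions of the last decoder level do.

```python
def unet0_shapes(img_rows, img_cols, num_start_filters=32, num_mask_channels=1, crop=16):
    """Trace (rows, cols, channels) per layer of get_unet0 (hypothetical helper).

    Assumption: ConvBN2 keeps rows/cols unchanged and returns the requested filter count.
    """
    # Four 2x2 poolings followed by four 2x upsamplings only line up with the
    # skip connections when both sides are divisible by 2**4 = 16.
    if img_rows % 16 or img_cols % 16:
        raise ValueError("img_rows and img_cols must be divisible by 16")
    if img_rows <= 2 * crop or img_cols <= 2 * crop:
        raise ValueError("input must be larger than the cropped border")

    shapes = {}
    rows, cols = img_rows, img_cols

    # Encoder: conv1..conv4, each followed by 2x2 max pooling.
    for level in range(1, 5):
        filters = num_start_filters * 2 ** (level - 1)
        shapes[f"conv{level}"] = (rows, cols, filters)
        rows, cols = rows // 2, cols // 2
        shapes[f"pool{level}"] = (rows, cols, filters)

    # Bottleneck.
    prev_filters = 16 * num_start_filters
    shapes["conv5"] = (rows, cols, prev_filters)

    # Decoder: up6..up9 concatenate the upsampled tensor with conv4..conv1.
    for level, skip in zip(range(6, 10), range(4, 0, -1)):
        rows, cols = rows * 2, cols * 2
        skip_filters = shapes[f"conv{skip}"][2]
        shapes[f"up{level}"] = (rows, cols, prev_filters + skip_filters)
        prev_filters = skip_filters  # each decoder conv outputs as many filters as its skip
        shapes[f"conv{level}"] = (rows, cols, prev_filters)

    # Cropping2D removes `crop` pixels from every side; the 1x1 conv sets the channel count.
    rows, cols = rows - 2 * crop, cols - 2 * crop
    shapes["crop9"] = (rows, cols, prev_filters)
    shapes["conv10"] = (rows, cols, num_mask_channels)
    return shapes


if __name__ == "__main__":
    for name, shape in unet0_shapes(128, 128).items():
        print(name, shape)
```

For a 128 x 128 input with the default 32 start filters, the trace gives a bottleneck of (8, 8, 512) and a final output of (96, 96, 1), so training masks would need to be cropped by the same 16-pixel border to match the model output.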