models.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:lsun-room 作者: leVirve 项目源码 文件源码
def fcn_Resnet50(input_shape=None, weight_decay=0.0002, batch_momentum=0.9, batch_shape=None, classes=40):
    """Build an FCN-style segmentation model on a ResNet-50 layout.

    The network is a 7x7 stem convolution, max pooling, four stages of
    ``conv_block`` / ``identity_block`` units (3, 4, 6 and 3 blocks), a 1x1
    convolution producing ``classes`` score maps, a stride-32 transposed
    convolution, and a fixed crop. Weights from the "notop" ResNet-50 file are
    then loaded by layer name.

    Args:
        input_shape: Shape of one input sample, without the batch dimension.
            Takes precedence over ``batch_shape`` when both are given.
        weight_decay: L2 regularization factor applied to the kernels of the
            stem convolution and the classifying convolution.
        batch_momentum: Accepted for interface compatibility; not used by
            this function.
        batch_shape: Full input shape including the batch dimension. Only
            used when ``input_shape`` is None.
        classes: Number of output score maps (filters of the classifying
            convolution and of the transposed convolution).

    Returns:
        A ``Model`` mapping the input tensor to the cropped score tensor.
    """
    # Prefer input_shape so existing callers keep their behavior; batch_shape
    # is only a fallback for calls that would otherwise pass shape=None.
    if input_shape is None and batch_shape is not None:
        img_input = Input(batch_shape=batch_shape)
    else:
        img_input = Input(shape=input_shape)

    # Channel axis for batch normalization; 3 matches the channels-last
    # ("tf_dim_ordering") weights file loaded below.
    bn_axis = 3

    # Stem: 7x7 stride-2 convolution, batch norm, ReLU and 3x3 stride-2 pooling.
    x = Conv2D(64, kernel_size=(7, 7), strides=(2, 2), padding='same', name='conv1',
               kernel_regularizer=l2(weight_decay))(img_input)
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Stage 2 (3 blocks).
    x = conv_block(3, [64, 64, 256], stage=2, block='a', strides=(1, 1))(x)
    x = identity_block(3, [64, 64, 256], stage=2, block='b')(x)
    x = identity_block(3, [64, 64, 256], stage=2, block='c')(x)

    # Stage 3 (4 blocks).
    x = conv_block(3, [128, 128, 512], stage=3, block='a')(x)
    x = identity_block(3, [128, 128, 512], stage=3, block='b')(x)
    x = identity_block(3, [128, 128, 512], stage=3, block='c')(x)
    x = identity_block(3, [128, 128, 512], stage=3, block='d')(x)

    # Stage 4 (6 blocks).
    x = conv_block(3, [256, 256, 1024], stage=4, block='a')(x)
    x = identity_block(3, [256, 256, 1024], stage=4, block='b')(x)
    x = identity_block(3, [256, 256, 1024], stage=4, block='c')(x)
    x = identity_block(3, [256, 256, 1024], stage=4, block='d')(x)
    x = identity_block(3, [256, 256, 1024], stage=4, block='e')(x)
    x = identity_block(3, [256, 256, 1024], stage=4, block='f')(x)

    # Stage 5 (3 blocks).
    x = conv_block(3, [512, 512, 2048], stage=5, block='a')(x)
    x = identity_block(3, [512, 512, 2048], stage=5, block='b')(x)
    x = identity_block(3, [512, 512, 2048], stage=5, block='c')(x)

    # Classifying layer: 1x1 convolution with one linear score map per class.
    x = Conv2D(filters=classes, kernel_size=(1, 1), strides=(1, 1), kernel_initializer='he_normal',
               activation='linear', padding='valid', kernel_regularizer=l2(weight_decay))(x)

    # Upsample the score maps with a stride-32 transposed convolution.
    x = Conv2DTranspose(filters=classes, kernel_initializer='he_normal', kernel_size=(64, 64),
                        strides=(32, 32), padding='valid', use_bias=False, name='upscore2')(x)
    # NOTE(review): the crop amounts are fixed constants and are not derived
    # from the input size here -- presumably tuned for one specific input
    # resolution; confirm before using other input shapes.
    x = Cropping2D(cropping=((19, 36), (19, 29)), name='score')(x)

    model = Model(img_input, x)

    # by_name=True restores only layers whose names match entries in the
    # weights file; the remaining layers keep their fresh initialization.
    weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
                            RES_WEIGHTS_PATH_NO_TOP, cache_subdir='models')
    model.load_weights(weights_path, by_name=True)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号