models.py source code

python

Project: keras-squeezenet    Author: dvbuntu
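# Keras 1.x API (Convolution2D(n, rows, cols, border_mode=...),
# Model(input=..., output=...)).
from keras.layers import (Input, Convolution2D, Activation, MaxPooling2D,
                          Dropout, ZeroPadding2D, AveragePooling2D,
                          Flatten, Dense)
from keras.models import Model

# fire_module(x, squeeze, expand) is defined elsewhere in models.py.
# Spatial sizes in the comments below assume it keeps height and width unchanged.


def get_small_squeezenet(nb_classes):
    """Small SqueezeNet variant for channels-first 3x32x32 inputs."""
    input_img = Input(shape=(3, 32, 32))
    # Stem: 3x3 conv, then 3x3 max pooling (stride defaults to 3): 32 -> 10
    x = Convolution2D(16, 3, 3, border_mode='same')(input_img)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(3, 3))(x)

    x = fire_module(x, 32, 128)
    x = fire_module(x, 32, 128)
    x = MaxPooling2D(pool_size=(2, 2))(x)  # 10 -> 5

    x = fire_module(x, 48, 192)
    x = fire_module(x, 48, 192)
    x = MaxPooling2D(pool_size=(2, 2))(x)  # 5 -> 2

    x = fire_module(x, 64, 256)
    x = Dropout(0.5)(x)

    # Pad 2x2 -> 4x4, then a 1x1 conv maps the channels to nb_classes
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Convolution2D(nb_classes, 1, 1, border_mode='valid')(x)

    # Global pooling was not available, so a 4x4 average pool over the
    # 4x4 feature map stands in for global average pooling
    x = AveragePooling2D(pool_size=(4, 4))(x)
    x = Flatten()(x)
    out = Dense(nb_classes, activation='softmax')(x)
    model = Model(input=input_img, output=[out])
    return model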
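The fixed 4x4 AveragePooling2D only works as a global pool if the feature map is exactly 4x4 at that point. A minimal pure-Python sketch of that arithmetic follows, assuming fire_module leaves height and width unchanged (its body is not part of this snippet) and that pooling uses 'valid' borders with the stride defaulting to the pool size, as Keras does. The helpers pool_out, conv_out and trace_small_squeezenet are hypothetical names used only for illustration.

```python
# Hypothetical helpers tracing the spatial size (height == width) through
# get_small_squeezenet. Assumption: fire_module preserves spatial size.


def pool_out(size, pool, stride=None):
    """Output size of a 'valid' pooling window; stride defaults to pool."""
    stride = pool if stride is None else stride
    return (size - pool) // stride + 1


def conv_out(size, kernel, padding):
    """Output size of a stride-1 convolution with 'same' or 'valid' padding."""
    if padding == "same":
        return size
    return size - kernel + 1


def trace_small_squeezenet(size=32):
    """Return (layer description, spatial size) pairs for each stage."""
    steps = []
    size = conv_out(size, 3, "same")
    steps.append(("conv 3x3 same", size))
    size = pool_out(size, 3)
    steps.append(("maxpool 3x3", size))
    steps.append(("2x fire_module", size))  # assumed size-preserving
    size = pool_out(size, 2)
    steps.append(("maxpool 2x2", size))
    steps.append(("2x fire_module", size))  # assumed size-preserving
    size = pool_out(size, 2)
    steps.append(("maxpool 2x2", size))
    steps.append(("fire_module", size))  # assumed size-preserving
    size = size + 2 * 1  # ZeroPadding2D(padding=(1, 1)) adds 1 on each side
    steps.append(("zeropad 1", size))
    size = conv_out(size, 1, "valid")
    steps.append(("conv 1x1 valid", size))
    size = pool_out(size, 4)
    steps.append(("avgpool 4x4", size))
    return steps


if __name__ == "__main__":
    for name, s in trace_small_squeezenet(32):
        print(f"{name}: {s}x{s}")
```

Under these assumptions the map shrinks 32 -> 10 -> 5 -> 2, so without ZeroPadding2D the last feature map would be 2x2, smaller than the 4x4 pooling window. Padding it to 4x4 lets the average pool reduce each class channel to a single 1x1 value.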