trans.py source code


Project: keras-squeezenet  Author: dvbuntu
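# Written against the Keras 1.x API (Convolution2D with subsample/border_mode,
# Model(input=..., output=...)) and channels-first (3, H, W) inputs.
from keras.layers import (Input, Convolution2D, Activation, MaxPooling2D,
                          Dropout, ZeroPadding2D, GlobalAveragePooling2D, Dense)
from keras.models import Model

# fire_module(x, squeeze, expand) is defined elsewhere (not shown in this excerpt).


def get_squeezenet(nb_classes, img_size=(64, 64)):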

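    # conv1: 96 filters of 7x7 with stride 2, then 3x3/2 max pooling.
    # The input is channels-first: (3, height, width).
    input_img = Input(shape=(3, img_size[0], img_size[1]))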
    x = Convolution2D(96, 7, 7, subsample=(2, 2), border_mode='valid')(input_img)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(x)

    # fire2-fire4, then max pooling
    x = fire_module(x, 16, 64)
    x = fire_module(x, 16, 64)
    x = fire_module(x, 32, 128)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(x)

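    # fire5-fire8, then max pooling. The SqueezeNet v1.0 paper uses expand=128
    # for fire5; this project uses 192.
    x = fire_module(x, 32, 192)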
    x = fire_module(x, 48, 192)
    x = fire_module(x, 48, 192)
    x = fire_module(x, 64, 256)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(x)

    # fire9, then dropout with rate 0.5
    x = fire_module(x, 64, 256)
    x = Dropout(0.5)(x)

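    # conv10: one pixel of zero padding on each side, then a 1x1 convolution
    # that produces one feature map per class
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Convolution2D(nb_classes, 1, 1, border_mode='valid')(x)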

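    # Global average pooling reduces each class map to a single value.
    x = GlobalAveragePooling2D()(x)
    # The original SqueezeNet applies softmax directly to the pooled conv10
    # outputs; this version adds a Dense layer before the softmax.
    out = Dense(nb_classes, activation='softmax')(x)
    model = Model(input=input_img, output=[out])
    return model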
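Because fire_module is not shown here, it can help to see how the feature-map size evolves through get_squeezenet. The sketch below is a hypothetical pure-Python shape tracer (the names conv_out, pool3x3_s2 and trace_squeezenet are illustrative, not part of the project) that follows the same layer sequence without importing Keras. It assumes fire_module keeps height and width unchanged and outputs 2 * expand channels (1x1 and 3x3 expand branches with `expand` filters each, concatenated), which is the usual fire-module layout; the project's actual helper may differ.

```python
# Hypothetical pure-Python shape tracer for get_squeezenet(); it does not use Keras.
# ASSUMPTION: fire_module(x, squeeze, expand) keeps height and width unchanged
# ('same' padding on its 3x3 branch) and concatenates a 1x1 and a 3x3 expand
# branch with `expand` filters each, so it outputs 2 * expand channels.


def conv_out(size, kernel, stride=1):
    """Output length along one axis for a 'valid' convolution or pooling."""
    return (size - kernel) // stride + 1


def pool3x3_s2(h, w):
    """MaxPooling2D(pool_size=(3, 3), strides=(2, 2)) with 'valid' padding."""
    return conv_out(h, 3, 2), conv_out(w, 3, 2)


def trace_squeezenet(nb_classes, img_size=(64, 64)):
    """Return [(layer name, output shape)] with shapes in (C, H, W) order."""
    c, h, w = 3, img_size[0], img_size[1]
    trace = [("input", (c, h, w))]

    # conv1: 96 filters, 7x7, stride 2, 'valid'
    c, h, w = 96, conv_out(h, 7, 2), conv_out(w, 7, 2)
    trace.append(("conv1", (c, h, w)))
    h, w = pool3x3_s2(h, w)
    trace.append(("pool1", (c, h, w)))

    # Fire modules only change the channel count (see ASSUMPTION above).
    for name, expand in [("fire2", 64), ("fire3", 64), ("fire4", 128)]:
        c = 2 * expand
        trace.append((name, (c, h, w)))
    h, w = pool3x3_s2(h, w)
    trace.append(("pool4", (c, h, w)))

    for name, expand in [("fire5", 192), ("fire6", 192),
                         ("fire7", 192), ("fire8", 256)]:
        c = 2 * expand
        trace.append((name, (c, h, w)))
    h, w = pool3x3_s2(h, w)
    trace.append(("pool8", (c, h, w)))

    c = 2 * 256
    trace.append(("fire9", (c, h, w)))  # Dropout leaves the shape unchanged

    h, w = h + 2, w + 2  # ZeroPadding2D((1, 1)) adds one pixel on every side
    trace.append(("zeropad", (c, h, w)))
    c = nb_classes  # a 1x1 'valid' conv keeps height and width
    trace.append(("conv10", (c, h, w)))
    trace.append(("global_avg_pool", (c,)))
    trace.append(("dense_softmax", (nb_classes,)))
    return trace


if __name__ == "__main__":
    for name, shape in trace_squeezenet(nb_classes=10):
        print(f"{name:16} {shape}")
```

Under that assumption, the default 64x64 input shrinks to 29x29 after conv1 and to 2x2 after the third max-pooling layer, so ZeroPadding2D turns it into a 4x4 map before conv10; a 224x224 input reaches conv10 at 14x14.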