resnet.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:convnet-study 作者: robertomest 项目源码 文件源码
def resnet_model(nb_blocks, bottleneck=True, l2_reg=1e-4):
    nb_channels = [16, 32, 64]
    inputs = Input((32, 32, 3))
    x = Convolution2D(16, 3, 3, border_mode='same', init='he_normal',
                      W_regularizer=l2(l2_reg), bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    for n, f in zip(nb_channels, [True, False, False]):
        x = block_stack(x, n, nb_blocks, bottleneck=bottleneck, l2_reg=l2_reg,
                        first=f)
    # Last BN-Relu
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(10)(x)
    x = Activation('softmax')(x)

    model = Model(input=inputs, output=x)
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号