res_auto.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:dem 作者: hengyuan-hu 项目源码 文件源码
def resnet_cifar10(repetations, input_shape, num_classes=10, weight_decay=1e-4):
    """Build a three-stage residual network classifier (CIFAR-10 style).

    Layout: a 3x3, 16-filter stem convolution; three stages built by
    ``_residual_block`` from ``_basic_block`` units with 16, 32 and 64
    filters (the last two stages are given a (2, 2) subsampling argument);
    a final BatchNormalization + ReLU; global average pooling; and a
    softmax ``Dense`` classifier.

    Args:
        repetations: number of residual units per stage, passed unchanged
            to every ``_residual_block`` call. (Spelling kept so keyword
            callers keep working.)
        input_shape: shape of a single input sample, without the batch axis.
            NOTE(review): ``BatchNormalization(axis=3)`` below assumes
            channels-last ordering such as (32, 32, 3) -- confirm.
        num_classes: number of softmax outputs. Defaults to 10, the value
            that used to be hard-coded.
        weight_decay: L2 regularization factor for the stem convolution and
            classifier weights. Defaults to 1e-4, the value that used to be
            hard-coded.

    Returns:
        An uncompiled Keras ``Model`` from ``input_shape`` inputs to
        ``num_classes`` softmax probabilities.
    """
    x = Input(shape=input_shape)

    # Stem. For a (32, 32, 3) input the feature map is (32, 32, 16).
    conv1 = Convolution2D(16, 3, 3, init='he_normal', border_mode='same',
                          W_regularizer=l2(weight_decay))(x)

    # Residual stages. The last argument is the subsampling handed to
    # _residual_block; the spatial sizes noted are for a 32x32 input and
    # presume that subsampling is applied once per stage (see
    # _residual_block).
    block_fn = _basic_block
    block1 = _residual_block(block_fn, 16, repetations, (1, 1))(conv1)  # (32, 32)
    block2 = _residual_block(block_fn, 32, repetations, (2, 2))(block1)  # (16, 16)
    block3 = _residual_block(block_fn, 64, repetations, (2, 2))(block2)  # (8, 8)

    # Final normalization + activation on the last stage's output.
    # In Keras 1.x, mode=2 normalizes with per-batch statistics at both
    # train and test time; kept as in the original.
    post_block_norm = BatchNormalization(mode=2, axis=3)(block3)
    post_block_relu = Activation("relu")(post_block_norm)

    # Classifier: global average pooling, then a softmax over num_classes.
    pooled = GlobalAveragePooling2D()(post_block_relu)
    dense = Dense(output_dim=num_classes, init="he_normal",
                  W_regularizer=l2(weight_decay), activation="softmax")(pooled)

    return Model(input=x, output=dense)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号