model.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:DeepWorks 作者: daigo0927 项目源码 文件源码
def build(input_shape, num_outputs,
          block_fn, repetitions):
    """Build a ResNet-style image classifier as a Keras functional ``Model``.

    Architecture:
      * Stem: 7x7 stride-2 ``Conv2D`` with 64 filters ('same' padding),
        ``BatchNormalization``, ReLU, then 3x3 stride-2 max-pooling.
      * Body: one ``_residual_block`` stage per entry of ``repetitions``.
        The filter count starts at 64 and doubles after every stage. Only
        the first stage is built with ``is_first_layer=True``.
      * Head: average-pooling over the full remaining spatial extent,
        ``Flatten``, then a softmax ``Dense`` layer with ``num_outputs`` units.

    Args:
        input_shape: Shape of one input sample, passed to ``Input(shape=...)``.
            NOTE(review): the pooling head reads dims 1 and 2 of the feature
            map as spatial, i.e. a channels-last ``(rows, cols, channels)``
            layout is assumed. Confirm against callers.
        num_outputs: Number of units in the final softmax layer.
        block_fn: Block constructor forwarded unchanged to ``_residual_block``.
        repetitions: Iterable of per-stage repetition counts. Each entry
            creates one stage.

    Returns:
        An uncompiled Keras ``Model`` mapping ``inputs`` to softmax ``outputs``.

    Raises:
        ValueError: If the spatial size of the final feature map is not
            statically known (e.g. ``None`` dims in ``input_shape``). The
            head needs a concrete pooling window.
    """
    inputs = Input(shape=input_shape)

    # Stem: downsample by 4x overall (stride-2 conv, then stride-2 pool).
    conv1 = Conv2D(64, (7, 7), strides=(2, 2),
                   padding='same')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Activation('relu')(conv1)
    pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2),
                         padding='same')(conv1)

    # Residual stages: width doubles after each stage. Only stage 0 is
    # flagged as the first layer so _residual_block can treat it specially
    # (presumably skipping its own downsampling right after the max-pool;
    # TODO confirm in _residual_block).
    x = pool1
    filters = 64
    for stage, reps in enumerate(repetitions):
        x = _residual_block(block_fn, filters=filters,
                            repetitions=reps,
                            is_first_layer=(stage == 0))(x)
        filters *= 2

    # NOTE(review): with pre-activation residual blocks (BN -> ReLU -> conv),
    # a final BN + ReLU before pooling is normally required, because the
    # last block's output is never normalized or activated. With post-
    # activation blocks it is redundant. Enable it depending on block_fn:
    # x = BatchNormalization()(x)
    # x = Activation('relu')(x)

    # Pool over the whole remaining spatial extent (global average pooling).
    # K.int_shape returns (batch, rows, cols, channels) for channels-last.
    _, rows, cols, _ = K.int_shape(x)
    if rows is None or cols is None:
        raise ValueError(
            "build() needs a statically known spatial size for the final "
            "feature map, got (%r, %r); pass a fully specified input_shape."
            % (rows, cols))
    pool2 = AveragePooling2D(pool_size=(rows, cols), strides=(1, 1))(x)
    flat1 = Flatten()(pool2)
    # `init` is a module-level kernel initializer defined outside this view.
    outputs = Dense(num_outputs, kernel_initializer=init,
                    activation='softmax')(flat1)

    model = Model(inputs=inputs, outputs=outputs)
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号