model_arch.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:deep-instrument-heroku 作者: bzamecnik 项目源码 文件源码
def create_model(input_shape, class_count):
    """Build and compile a small VGG-style convolutional classifier.

    The input sample is treated as a single-channel 2-D "image".
    - A channel axis of size 1 is appended, then the input is batch-normalized.
    - Four convolution blocks follow (32, 64, 64, 64 filters). Each block is
      two 3x3 "same"-padded convolutions, each followed by batch normalization
      and ELU, then 2x2 max pooling and dropout of 0.1.
    - Finally the result is flattened and passed through a dense layer of
      width ``class_count``, batch normalization, and softmax.

    Args:
        input_shape: Shape of a single input sample, without the batch axis.
            Any sequence is accepted (tuple or list); it is converted to a
            tuple. Presumably 2-D (e.g. time x frequency) -- TODO confirm
            against callers.
        class_count: Number of output classes, i.e. the width of the final
            softmax layer.

    Returns:
        A Keras ``Model`` already compiled with ``categorical_crossentropy``
        loss, the ``adam`` optimizer and the ``accuracy`` metric.
    """
    # Normalize to a tuple: the channel-axis append below uses tuple
    # concatenation, which raises TypeError if a list is passed in.
    input_shape = tuple(input_shape)

    inputs = Input(shape=input_shape)

    # Add one more (trailing) dimension of size 1 so 2-D convolutions can be
    # applied. NOTE(review): this assumes channels-last dim ordering in the
    # Keras backend config -- confirm.
    x = Reshape(input_shape + (1,))(inputs)

    x = BatchNormalization()(x)

    def convolution_block(filter_count, dropout):
        """Return a callable applying one conv-conv-pool-dropout block."""
        def create(x):
            # Keras 1.x signature: Convolution2D(nb_filter, nb_row, nb_col, ...).
            x = Convolution2D(filter_count, 3, 3, border_mode='same')(x)
            x = BatchNormalization()(x)
            x = ELU()(x)
            x = Convolution2D(filter_count, 3, 3, border_mode='same')(x)
            x = BatchNormalization()(x)
            x = ELU()(x)
            # 2x2 pooling roughly halves each spatial dimension per block.
            x = MaxPooling2D(pool_size=(2, 2))(x)
            x = Dropout(dropout)(x)
            return x
        return create

    # Four blocks in sequence, identical except for their filter counts.
    for filter_count in (32, 64, 64, 64):
        x = convolution_block(filter_count=filter_count, dropout=0.1)(x)

    x = Flatten()(x)

    x = Dense(class_count)(x)
    # Batch-normalize the logits before the softmax (kept as in the
    # original design).
    x = BatchNormalization()(x)
    predictions = Activation('softmax')(x)

    model = Model(inputs, predictions)

    model.compile(loss='categorical_crossentropy', optimizer='adam',
        metrics=['accuracy'])

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号