multilayer.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:EUSIPCO2017 作者: Veleslavia 项目源码 文件源码
def build_model(n_classes):
    """Build a multi-branch convolutional classifier over a mel-spectrogram input.

    The single-channel input is fed to six parallel convolution branches, one
    for every combination of kernel height in ``m_sizes`` and kernel width in
    ``n_sizes``. Each branch applies convolution, batch normalization, ELU and
    max pooling. The branch outputs are concatenated along the channel axis
    and passed through two further conv/BN/ELU/pool stages, then flattened
    into a dense layer and a softmax output layer.

    Args:
        n_classes: Number of units in the final softmax layer.

    Returns:
        The ``Model`` mapping the spectrogram input to the softmax output.
        The model is not compiled here.
    """
    # Input layout and channel axis depend on the backend dimension ordering:
    # channels-first for 'th', channels-last otherwise.
    if K.image_dim_ordering() == 'th':
        input_shape = (1, N_MEL_BANDS, SEGMENT_DUR)
        channel_axis = 1
    else:
        input_shape = (N_MEL_BANDS, SEGMENT_DUR, 1)
        channel_axis = 3
    melgram_input = Input(shape=input_shape)

    maxpool_const = 4
    m_sizes = [5, 80]
    n_sizes = [1, 3, 5]
    n_filters = [128, 64, 32]

    # Floor division keeps the pooling window integral; plain `/` produces
    # floats under Python 3 (or with `from __future__ import division`).
    # The window is the same for every branch, so compute it once.
    branch_pool_size = (N_MEL_BANDS // maxpool_const,
                        SEGMENT_DUR // maxpool_const)

    branches = []

    for m_i in m_sizes:
        # Filter counts are paired positionally with kernel widths.
        for filters, n_i in zip(n_filters, n_sizes):
            # Layer names encode the kernel as "<width>_<height>_<kind>".
            prefix = str(n_i) + '_' + str(m_i) + '_'
            x = Convolution2D(filters, m_i, n_i,
                              border_mode='same',
                              init='he_normal',
                              W_regularizer=l2(1e-5),
                              name=prefix + 'conv')(melgram_input)
            x = BatchNormalization(axis=channel_axis, mode=0,
                                   name=prefix + 'bn')(x)
            x = ELU()(x)
            x = MaxPooling2D(pool_size=branch_pool_size,
                             name=prefix + 'pool')(x)
            branches.append(x)

    # All branches share the same pooled spatial size, so they can be
    # stacked along the channel axis.
    x = merge(branches, mode='concat', concat_axis=channel_axis)

    x = Dropout(0.25)(x)
    x = Convolution2D(128, 3, 3, init='he_normal', W_regularizer=l2(1e-5), border_mode='same', name='conv2')(x)
    x = BatchNormalization(axis=channel_axis, mode=0, name='bn2')(x)
    x = ELU()(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(x)

    x = Dropout(0.25)(x)
    x = Convolution2D(128, 3, 3, init='he_normal', W_regularizer=l2(1e-5), border_mode='same', name='conv3')(x)
    x = BatchNormalization(axis=channel_axis, mode=0, name='bn3')(x)
    x = ELU()(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool3')(x)

    x = Flatten(name='flatten')(x)
    x = Dropout(0.5)(x)
    x = Dense(256, init='he_normal', W_regularizer=l2(1e-5), name='fc1')(x)
    x = ELU()(x)
    x = Dropout(0.5)(x)
    x = Dense(n_classes, init='he_normal', W_regularizer=l2(1e-5), activation='softmax', name='prediction')(x)
    model = Model(melgram_input, x)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号