nasnet.py source code

python

Project: keras-contrib  Author: farizrahman4u
def _add_auxiliary_head(x, classes, weight_decay):
    '''Adds an auxiliary classification head used while training the model.

    From section A.7 "Training of ImageNet models" of the paper, all NASNet models are
    trained using an auxiliary classifier around 2/3 of the depth of the network, with
    a loss weight of 0.4.

    # Arguments
        x: input tensor
        classes: number of output classes
        weight_decay: l2 regularization weight

    # Returns
        a keras Tensor
    '''
    # Axis indices of height, width and channels for the active image data format.
    img_height = 1 if K.image_data_format() == 'channels_last' else 2
    img_width = 2 if K.image_data_format() == 'channels_last' else 3
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    with K.name_scope('auxiliary_branch'):
        auxiliary_x = Activation('relu')(x)
        auxiliary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxiliary_x)
        auxiliary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection',
                             kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxiliary_x)
        auxiliary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
                                         name='aux_bn_projection')(auxiliary_x)
        auxiliary_x = Activation('relu')(auxiliary_x)

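        # The kernel spans the whole pooled feature map, so this 'valid' convolution
        # collapses it to 1x1. K.int_shape avoids relying on the private _keras_shape.
        aux_shape = K.int_shape(auxiliary_x)
        auxiliary_x = Conv2D(768, (aux_shape[img_height], aux_shape[img_width]),
                             padding='valid', use_bias=False, kernel_initializer='he_normal',
                             kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxiliary_x)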
        auxiliary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
                                         name='aux_bn_reduction')(auxiliary_x)
        auxiliary_x = Activation('relu')(auxiliary_x)

        auxiliary_x = GlobalAveragePooling2D()(auxiliary_x)
        auxiliary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay),
                            name='aux_predictions')(auxiliary_x)
    return auxiliary_x
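
The shape arithmetic of the auxiliary branch can be checked without Keras. The sketch below is not part of keras-contrib: `pooled_size` and `aux_head_shapes` are hypothetical helper names, the 14x14x1024 input is an assumed example rather than a value taken from the model, and a `channels_last` layout is assumed throughout. It traces how the 5x5, stride-3 `'valid'` pooling shrinks the feature map and how the full-size `'valid'` convolution then reduces it to 1x1.

```python
def pooled_size(size, pool=5, stride=3):
    """Output length along one spatial axis for 'valid' pooling."""
    if size < pool:
        raise ValueError("feature map is smaller than the pooling window")
    return (size - pool) // stride + 1


def aux_head_shapes(height, width, in_channels, classes):
    """Trace per-sample output shapes through the auxiliary head (channels_last)."""
    h, w = pooled_size(height), pooled_size(width)
    return [
        ("aux_pool", (h, w, in_channels)),       # pooling keeps the channel count
        ("aux_conv_projection", (h, w, 128)),    # 1x1 conv, 'same' padding
        ("aux_conv_reduction", (1, 1, 768)),     # (h, w) kernel, 'valid' padding
        ("global_average_pooling", (768,)),      # averaging a 1x1 map drops the spatial axes
        ("aux_predictions", (classes,)),         # softmax over the classes
    ]


if __name__ == "__main__":
    # Assumed example input: a 14x14 feature map with 1024 channels, 1000 classes.
    for name, shape in aux_head_shapes(14, 14, 1024, 1000):
        print(name, shape)
```

Because the reduction convolution already yields a 1x1 map, the `GlobalAveragePooling2D` layer that follows it effectively only flattens the tensor to a 768-vector before the final `Dense` layer.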