def _add_auxiliary_head(x, classes, weight_decay):
    '''Adds an auxiliary classification head for training the model.

    From section A.7 "Training of ImageNet models" of the paper, all NASNet
    models are trained using an auxiliary classifier around 2/3 of the depth
    of the network, with a loss weight of 0.4. This function only builds the
    classifier branch; no loss weight is applied here.

    Layer stack built on top of `x`:
        ReLU -> 5x5 average pooling (stride 3, 'valid')
        -> 1x1 conv, 128 filters -> batch norm -> ReLU
        -> conv whose kernel spans the whole remaining feature map,
           768 filters -> batch norm -> ReLU
        -> global average pooling -> softmax Dense with `classes` units

    # Arguments
        x: input tensor. 2D pooling/convolution layers are applied to it,
            and its spatial dimensions after pooling must be statically
            known because they are used as a convolution kernel size.
        classes: number of output classes
        weight_decay: l2 regularization weight

    # Returns
        a keras Tensor: the output of the `aux_predictions` softmax layer.
    '''
    # Resolve the layout-dependent axes once instead of querying the backend
    # separately for each of them.
    if K.image_data_format() == 'channels_first':
        channel_axis, img_height, img_width = 1, 2, 3
    else:
        channel_axis, img_height, img_width = -1, 1, 2

    with K.name_scope('auxiliary_branch'):
        auxiliary_x = Activation('relu')(x)
        auxiliary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid',
                                       name='aux_pool')(auxiliary_x)

        auxiliary_x = Conv2D(128, (1, 1), padding='same', use_bias=False,
                             name='aux_conv_projection',
                             kernel_initializer='he_normal',
                             kernel_regularizer=l2(weight_decay))(auxiliary_x)
        auxiliary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY,
                                         epsilon=_BN_EPSILON,
                                         name='aux_bn_projection')(auxiliary_x)
        auxiliary_x = Activation('relu')(auxiliary_x)

        # Use the public backend API for the static shape rather than the
        # private `_keras_shape` attribute, which is not present on every
        # Keras tensor implementation.
        static_shape = K.int_shape(auxiliary_x)
        # A 'valid' convolution whose kernel equals the feature-map size
        # collapses the spatial dimensions to 1x1.
        full_map_kernel = (static_shape[img_height], static_shape[img_width])
        auxiliary_x = Conv2D(768, full_map_kernel, padding='valid',
                             use_bias=False, kernel_initializer='he_normal',
                             kernel_regularizer=l2(weight_decay),
                             name='aux_conv_reduction')(auxiliary_x)
        auxiliary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY,
                                         epsilon=_BN_EPSILON,
                                         name='aux_bn_reduction')(auxiliary_x)
        auxiliary_x = Activation('relu')(auxiliary_x)

        auxiliary_x = GlobalAveragePooling2D()(auxiliary_x)
        auxiliary_x = Dense(classes, activation='softmax',
                            kernel_regularizer=l2(weight_decay),
                            name='aux_predictions')(auxiliary_x)

    return auxiliary_x
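# Page-navigation residue from the scraped web page (not part of the code):
#   "评论列表" (comment list)
#   "文章目录" (article table of contents)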