def _normal_A(ip, p, filters, weight_decay=5e-5, id=None):
    """Build one Normal cell for NASNet-A (Fig. 4 in the paper).

    The current input ``ip`` goes through ReLU -> 1x1 Conv2D (no bias,
    he_normal init, L2-regularized) -> BatchNormalization. The previous input
    ``p`` goes through ``_adjust_block``. Five two-branch combinations of
    those two tensors are each summed with ``add``. The five sums are then
    concatenated, together with the adjusted ``p``, along the channel axis
    (axis 1 for 'channels_first' data format, otherwise -1).

    # Arguments:
        ip: input tensor `x` of the current cell.
        p: input tensor `p`, presumably the output of the previous cell.
            It is passed through `_adjust_block(p, ip, ...)` before use,
            presumably to make it shape-compatible with `ip` (TODO confirm
            against that helper).
        filters: number of filters for the 1x1 projection and for every
            separable-conv branch.
        weight_decay: L2 regularization factor. It is applied to the 1x1
            projection kernel here and forwarded to `_adjust_block` and
            `_separable_conv_block`.
        id: identifier substituted (via `%s` formatting) into every layer
            name and into the enclosing name scope. It is the only varying
            part of those names, so it presumably must differ between cells.

    # Returns:
        A tuple `(x, ip)`. `x` is the concatenated cell output. `ip` is the
        unmodified input, presumably so the caller can pass it on as `p`
        for the next cell.
    """
    def _named(template):
        # Every layer / scope name is a '%s' template completed with the id.
        return template % id

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = -1

    # NOTE(review): layer names and creation order are kept exactly as in the
    # original. Pretrained weights are presumably matched against them, so do
    # not reorder or rename.
    with K.name_scope(_named('normal_A_block_%s')):
        prev = _adjust_block(p, ip, filters, weight_decay, id)

        # Project the current input: ReLU -> 1x1 conv -> BN.
        cur = Activation('relu')(ip)
        cur = Conv2D(filters, (1, 1), strides=(1, 1), padding='same',
                     name=_named('normal_conv_1_%s'), use_bias=False,
                     kernel_initializer='he_normal',
                     kernel_regularizer=l2(weight_decay))(cur)
        cur = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY,
                                 epsilon=_BN_EPSILON,
                                 name=_named('normal_bn_1_%s'))(cur)

        # Block 1: 5x5 separable conv on the projection, plus a separable
        # conv with the helper's default kernel size on the adjusted input.
        with K.name_scope('block_1'):
            left = _separable_conv_block(cur, filters, kernel_size=(5, 5),
                                         weight_decay=weight_decay,
                                         id=_named('normal_left1_%s'))
            right = _separable_conv_block(prev, filters,
                                          weight_decay=weight_decay,
                                          id=_named('normal_right1_%s'))
            comb_1 = add([left, right], name=_named('normal_add_1_%s'))

        # Block 2: 5x5 and 3x3 separable convs, both on the adjusted input.
        with K.name_scope('block_2'):
            left = _separable_conv_block(prev, filters, (5, 5),
                                         weight_decay=weight_decay,
                                         id=_named('normal_left2_%s'))
            right = _separable_conv_block(prev, filters, (3, 3),
                                          weight_decay=weight_decay,
                                          id=_named('normal_right2_%s'))
            comb_2 = add([left, right], name=_named('normal_add_2_%s'))

        # Block 3: 3x3 stride-1 average pool of the projection, plus the
        # adjusted input as an identity branch.
        with K.name_scope('block_3'):
            pooled = AveragePooling2D((3, 3), strides=(1, 1), padding='same',
                                      name=_named('normal_left3_%s'))(cur)
            comb_3 = add([pooled, prev], name=_named('normal_add_3_%s'))

        # Block 4: two distinct 3x3 stride-1 average-pool layers, both
        # applied to the adjusted input.
        with K.name_scope('block_4'):
            left = AveragePooling2D((3, 3), strides=(1, 1), padding='same',
                                    name=_named('normal_left4_%s'))(prev)
            right = AveragePooling2D((3, 3), strides=(1, 1), padding='same',
                                     name=_named('normal_right4_%s'))(prev)
            comb_4 = add([left, right], name=_named('normal_add_4_%s'))

        # Block 5: separable conv (helper's default kernel size) on the
        # projection, plus the projection itself as an identity branch.
        with K.name_scope('block_5'):
            conv = _separable_conv_block(cur, filters,
                                         weight_decay=weight_decay,
                                         id=_named('normal_left5_%s'))
            comb_5 = add([conv, cur], name=_named('normal_add_5_%s'))

        # The adjusted input is concatenated alongside the five block sums.
        out = concatenate([prev, comb_1, comb_2, comb_3, comb_4, comb_5],
                          axis=channel_axis, name=_named('normal_concat_%s'))
    return out, ip
评论列表
文章目录