def residual_network(img_input, classes_num=10, stack_n=5):
    """Assemble a pre-activation residual network on top of ``img_input``.

    Layout: a 3x3 stem convolution (16 filters), three stages of residual
    blocks with 16, 32 and 64 filters, a closing BatchNorm + ReLU, global
    average pooling, and a softmax ``Dense`` classifier.

    Stage one holds ``stack_n`` blocks at stride 1. Stages two and three
    each open with one stride-2 block (with a 1x1 projection shortcut) and
    continue with ``stack_n - 1`` stride-1 blocks. Counting two 3x3
    convolutions per block plus the stem and the classifier gives
    ``stack_n * 3 * 2 + 2`` weighted layers (32 when ``stack_n`` is 5).

    NOTE: ``weight_decay`` is not a parameter; it is resolved from the
    surrounding scope when the function runs and feeds every L2
    kernel regularizer created here.

    Args:
        img_input: Tensor the stem convolution is applied to. The shape
            remarks in the body assume a 32x32x3 image input.
        classes_num: Number of units in the final softmax layer.
        stack_n: Block count of the first stage; later stages use one
            downsampling block plus ``stack_n - 1`` regular blocks.

    Returns:
        The output tensor of the softmax ``Dense`` layer.
    """

    def conv(filters, kernel, strides):
        # Every convolution shares this configuration; a fresh L2
        # regularizer object is created for each layer.
        return Conv2D(filters,
                      kernel_size=kernel,
                      strides=strides,
                      padding='same',
                      kernel_initializer="he_normal",
                      kernel_regularizer=regularizers.l2(weight_decay))

    def residual_block(inputs, out_channel, increase=False):
        # BN -> ReLU -> conv, twice; only the first conv may be strided.
        stride = (2, 2) if increase else (1, 1)

        y = BatchNormalization()(inputs)
        y = Activation('relu')(y)
        y = conv(out_channel, (3, 3), stride)(y)
        y = BatchNormalization()(y)
        y = Activation('relu')(y)
        y = conv(out_channel, (3, 3), (1, 1))(y)

        if not increase:
            # Shapes already agree: plain identity shortcut.
            return add([inputs, y])

        # Strided 1x1 projection of the raw block input so the shortcut
        # matches the residual branch before the two are summed.
        shortcut = conv(out_channel, (1, 1), (2, 2))(inputs)
        return add([y, shortcut])

    # Stem. For a 32x32x3 input: 32x32x3 -> 32x32x16.
    x = conv(16, (3, 3), (1, 1))(img_input)

    # Stage 1: 32x32x16 -> 32x32x16.
    for _ in range(stack_n):
        x = residual_block(x, 16, False)

    # Stage 2: 32x32x16 -> 16x16x32, stage 3: 16x16x32 -> 8x8x64.
    for channels in (32, 64):
        x = residual_block(x, channels, True)
        for _ in range(1, stack_n):
            x = residual_block(x, channels, False)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

    # Classifier: 64 features -> classes_num probabilities.
    return Dense(classes_num,
                 activation='softmax',
                 kernel_initializer="he_normal",
                 kernel_regularizer=regularizers.l2(weight_decay))(x)
# Comment list
# Table of contents