def wide_residual_network(img_input, classes_num, depth, k):
    """Build a Wide ResNet (WRN-depth-k) classifier on top of ``img_input``.

    Architecture:

    * A 3x3 stem convolution with 16 filters.
    * Three groups of pre-activation residual blocks (BN -> ReLU -> conv)
      with ``16*k``, ``32*k`` and ``64*k`` filters. The second and third
      groups downsample with stride 2 in their first block.
    * A final BN -> ReLU -> global average pooling -> softmax Dense layer.

    Args:
        img_input: Input tensor of the model (presumably a Keras ``Input``
            image tensor -- confirm against the caller).
        classes_num: Number of output classes (units of the final Dense layer).
        depth: Total network depth. Each group holds ``(depth - 4) // 6``
            residual blocks, but always at least one. Depths not of the form
            ``6 * n + 4`` are truncated.
        k: Widening factor applied to the filter counts of all three groups.

    Returns:
        The softmax output tensor of the final Dense layer.

    Note:
        Relies on a module-level ``weight_decay`` (L2 coefficient) and on the
        Keras names used below (``Conv2D``, ``BatchNormalization``, ...) being
        imported elsewhere in this file.
    """
    print('Wide-Resnet %dx%d' % (depth, k))
    n_filters = [16, 16 * k, 32 * k, 64 * k]
    n_stack = (depth - 4) // 6

    def conv(x, filters, kernel_size=(3, 3), strides=(1, 1)):
        """Apply an L2-regularised, He-initialised, 'same'-padded Conv2D."""
        return Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                      padding='same',
                      kernel_initializer=he_normal(),
                      kernel_regularizer=regularizers.l2(weight_decay))(x)

    def residual_block(x, in_filters, out_filters, increase_filter=False):
        """Pre-activation residual block mapping in_filters -> out_filters."""
        first_stride = (2, 2) if increase_filter else (1, 1)
        pre_relu = Activation('relu')(BatchNormalization()(x))
        conv_1 = conv(pre_relu, out_filters, strides=first_stride)
        relu_1 = Activation('relu')(BatchNormalization()(conv_1))
        conv_2 = conv(relu_1, out_filters)
        # An identity shortcut is only valid when the residual branch keeps
        # both the channel count and the spatial size. Otherwise, project x
        # with a 1x1 conv using the same stride.
        if increase_filter or in_filters != out_filters:
            shortcut = conv(x, out_filters, kernel_size=(1, 1),
                            strides=first_stride)
        else:
            shortcut = x
        return add([conv_2, shortcut])

    def wide_residual_layer(x, in_filters, out_filters, increase_filter=False):
        """Stack one group: a (possibly projecting) block, then identity blocks."""
        x = residual_block(x, in_filters, out_filters, increase_filter)
        # NOTE: channel counts are passed explicitly. The previous version
        # assigned `in_filters` inside this nested function, which only
        # created a local. That gave every later block a needless 1x1
        # projection shortcut. Models built now have fewer layers, so weights
        # saved from the old structure will not load by topology.
        for _ in range(1, n_stack):
            x = residual_block(x, out_filters, out_filters)
        return x

    x = conv(img_input, n_filters[0])
    x = wide_residual_layer(x, n_filters[0], n_filters[1])
    x = wide_residual_layer(x, n_filters[1], n_filters[2], increase_filter=True)
    x = wide_residual_layer(x, n_filters[2], n_filters[3], increase_filter=True)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(classes_num, activation='softmax',
              kernel_initializer=he_normal(),
              kernel_regularizer=regularizers.l2(weight_decay))(x)
    return x
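# 评论列表 (comment list) -- leftover navigation text from the source web page
# 文章目录 (table of contents) -- leftover navigation text from the source web page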