def resnet_cifar10(repetations, input_shape, nb_classes=10):
    """Build a ResNet-style image classifier with three residual stages.

    The network is an initial 3x3 convolution, three `_residual_block`
    stages with 16, 32 and 64 filters, a final batch-norm + ReLU, global
    average pooling and a softmax classifier.

    Args:
        repetations: Passed unchanged to every `_residual_block` call;
            presumably the number of basic blocks per stage -- confirm
            against `_residual_block`. (Spelling kept so existing keyword
            callers keep working.)
        input_shape: Shape handed to `Input`, without the batch dimension.
            NOTE(review): `axis=3` in the batch normalization below and the
            size comments assume channels-last input such as (32, 32, 3) --
            confirm against the backend's dim ordering.
        nb_classes: Number of units in the softmax output layer. Defaults
            to 10, the value that used to be hard-coded.

    Returns:
        An uncompiled `Model` mapping the input tensor to the softmax output.
    """
    x = Input(shape=input_shape)
    conv1 = Convolution2D(16, 3, 3, init='he_normal', border_mode='same',
                          W_regularizer=l2(1e-4))(x)
    # feature map size (32, 32, 16) for a 32x32 input

    # Build residual blocks. The last argument is presumably the stride of
    # the stage: (1, 1) keeps the resolution, (2, 2) halves it.
    block_fn = _basic_block
    block1 = _residual_block(block_fn, 16, repetations, (1, 1))(conv1)
    # feature map size (16, 16) after the next stage
    block2 = _residual_block(block_fn, 32, repetations, (2, 2))(block1)
    # feature map size (8, 8) after the next stage
    block3 = _residual_block(block_fn, 64, repetations, (2, 2))(block2)

    # Normalize and activate the output of the last stage before pooling.
    post_block_norm = BatchNormalization(mode=2, axis=3)(block3)
    post_block_relu = Activation("relu")(post_block_norm)

    # Classifier block: collapse each feature map to one value, then softmax.
    pool2 = GlobalAveragePooling2D()(post_block_relu)
    dense = Dense(output_dim=nb_classes, init="he_normal",
                  W_regularizer=l2(1e-4), activation="softmax")(pool2)

    model = Model(input=x, output=dense)
    return model
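# Comment list (navigation text left over from the source web page)
# Article table of contents (navigation text left over from the source web page)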