def __grouped_convolution_block(input, grouped_channels, cardinality, strides, weight_decay=5e-4):
    ''' Adds a grouped convolution block. It is an equivalent block from the paper

    The input is sliced along the channel axis into `cardinality` consecutive
    groups of `grouped_channels` channels, each group is passed through its own
    3x3 convolution, and the results are concatenated along the channel axis,
    followed by batch normalization and LeakyReLU.

    Args:
        input: input tensor. It is indexed with four axes, so it must be 4D;
            the slices assume at least `cardinality * grouped_channels`
            channels on the channel axis.
        grouped_channels: grouped number of filters (channels per group)
        cardinality: cardinality factor describing the number of groups
        strides: performs strided convolution for downscaling if > 1
        weight_decay: weight decay term (L2 kernel regularization factor)

    Returns: a keras tensor
    '''
    # Resolve the data format once; it selects both the axis used for
    # BatchNormalization/concatenate and the axis that gets sliced per group.
    channels_first = K.image_data_format() == 'channels_first'
    channel_axis = 1 if channels_first else -1

    if cardinality == 1:
        # with cardinality 1, it is a standard convolution
        x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
                   kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(input)
        x = BatchNormalization(axis=channel_axis)(x)
        x = LeakyReLU()(x)
        return x

    group_list = []
    for c in range(cardinality):
        lo = c * grouped_channels
        hi = lo + grouped_channels

        # The slice bounds are bound as default arguments so that each Lambda
        # keeps its own group range instead of reading the loop variable late.
        # The branch is taken outside the lambda: writing
        # `lambda z: A if cond else lambda z: B` makes the whole conditional
        # the lambda body, which returns a function (not a tensor) in the
        # channels_first case.
        if channels_first:
            group_slice = Lambda(lambda z, lo=lo, hi=hi: z[:, lo:hi, :, :])
        else:
            group_slice = Lambda(lambda z, lo=lo, hi=hi: z[:, :, :, lo:hi])

        x = group_slice(input)
        x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
                   kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(x)
        group_list.append(x)

    group_merge = concatenate(group_list, axis=channel_axis)
    x = BatchNormalization(axis=channel_axis)(group_merge)
    x = LeakyReLU()(x)
    return x
评论列表
文章目录