def group_layer(self, group_num, filters, name, kernel_regularizer_l2):
    """Return a closure that builds a (grouped) convolution block.

    Each tower is: 1x1 Conv2D -> 3x3 Conv2D (with an L2 kernel
    regularizer) -> PReLU, all using ``padding='same'``.

    * ``group_num == 1``: a single tower with ``filters`` channels is built
      and its output is returned directly (both convolutions use
      ``IdentityConv()`` as the kernel initializer).
    * ``group_num > 1``: ``group_num`` towers are built in parallel on the
      same input and concatenated along the channel axis. Tower ``i`` uses
      ``GroupIdentityConv(i, group_num)`` for its 1x1 kernel initializer
      (presumably selecting/identity-mapping group ``i``'s slice of the
      input channels -- TODO confirm against the initializer definition).

    Args:
        group_num: Number of parallel towers (groups); must be >= 1.
        filters: Total number of output channels across all groups. They are
            split with integer division and the last group takes the
            remainder, e.g. filters=201, group_num=4 -> 50, 50, 50, 51.
        name: Prefix for layer names; layers are named
            ``<name>_conv2d_<group>_1`` (1x1) and ``<name>_conv2d_<group>_2``
            (3x3).
        kernel_regularizer_l2: L2 factor passed to ``regularizers.l2`` for
            the 3x3 convolutions.

    Returns:
        A function ``f(input)`` that applies the block to a Keras tensor and
        returns the resulting tensor.

    Raises:
        ValueError: if ``group_num < 1`` or ``filters < group_num`` (some
            group would otherwise receive zero filters).
    """
    if group_num < 1:
        raise ValueError('group_num must be >= 1, got %r' % (group_num,))
    if filters < group_num:
        raise ValueError('filters (%r) must be >= group_num (%r)'
                         % (filters, group_num))

    # Integer split so every Conv2D gets an int filter count and the group
    # sizes sum exactly to `filters` (true division would yield floats,
    # e.g. 201 / 4 = 50.25, and lose channels when truncated).
    base = filters // group_num
    group_filters = [base] * (group_num - 1) + [filters - base * (group_num - 1)]

    def f(input):
        if group_num == 1:
            tower = Conv2D(filters, (1, 1), name=name + '_conv2d_0_1', padding='same',
                           kernel_initializer=IdentityConv())(input)
            tower = Conv2D(filters, (3, 3), name=name + '_conv2d_0_2', padding='same',
                           kernel_initializer=IdentityConv(),
                           kernel_regularizer=regularizers.l2(kernel_regularizer_l2))(tower)
            return PReLU()(tower)

        group_output = []
        for i, filter_num in enumerate(group_filters):
            prefix = name + '_conv2d_' + str(i)
            tower = Conv2D(filter_num, (1, 1), name=prefix + '_1', padding='same',
                           kernel_initializer=GroupIdentityConv(i, group_num))(input)
            tower = Conv2D(filter_num, (3, 3), name=prefix + '_2', padding='same',
                           kernel_initializer=IdentityConv(),
                           kernel_regularizer=regularizers.l2(kernel_regularizer_l2))(tower)
            group_output.append(PReLU()(tower))

        # Channel axis: 1 for channels_first (NCHW); otherwise the last axis
        # (channels_last, NHWC), which is axis 3 for the 4-D Conv2D outputs.
        axis = 1 if K.image_data_format() == 'channels_first' else -1
        return Concatenate(axis=axis)(group_output)

    return f
# Web-page residue from the scraped source ("Comment list", "Table of contents"); not part of the code.