def build(self, input_shape):
    """Create the weights of the two-layer bottleneck MLP.

    Allocates, in this order, ``kernel1`` of shape
    ``(channels, channels // ratio)``, the optional ``bias1`` of shape
    ``(channels // ratio,)``, ``kernel2`` of shape
    ``(channels // ratio, channels)`` and the optional ``bias2`` of shape
    ``(channels,)``. The creation order is kept stable because it determines
    the order of ``get_weights()`` / ``set_weights()``.

    NOTE(review): judging by the shapes this is presumably the shared MLP of
    a channel-attention (squeeze-and-excitation style) block; ``call`` is not
    visible here, so confirm against it.

    Args:
        input_shape: Shape of the layer input; must have exactly 4 entries.
            The channel entry is read from index 1 when ``self.data_format``
            is ``'channels_first'`` and from index 3 otherwise.

    Raises:
        ValueError: if ``input_shape`` does not have 4 entries, or if its
            channel dimension is undefined (``None``).
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let a malformed shape through silently.
    if len(input_shape) != 4:
        raise ValueError(
            'Expected a 4D input shape, got %d dimensions: %s'
            % (len(input_shape), input_shape))
    self.input_spec = InputSpec(shape=input_shape)
    channel_axis = 1 if self.data_format == 'channels_first' else 3
    channels = input_shape[channel_axis]
    # TF1-style TensorShape indexing yields a ``Dimension`` object; unwrap it
    # so the ``None`` check below works for tuples and TensorShapes alike.
    channels = getattr(channels, 'value', channels)
    if channels is None:
        raise ValueError(
            'The channel dimension of the inputs should be defined. '
            'Found `None`.')
    # Clamp the bottleneck to at least one unit: with channels < ratio the
    # floor division yields 0, creating zero-size weights through which the
    # second layer's output can no longer depend on the input.
    hidden = max(1, channels // self.ratio)

    # First (squeezing) projection: channels -> hidden.
    self.kernel1 = self.add_weight(shape=(channels, hidden),
                                   initializer=self.kernel_initializer,
                                   name='kernel1',
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
    if self.use_bias:
        self.bias1 = self.add_weight(shape=(hidden,),
                                     initializer=self.bias_initializer,
                                     name='bias1',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)
    else:
        self.bias1 = None

    # Second (expanding) projection: hidden -> channels.
    self.kernel2 = self.add_weight(shape=(hidden, channels),
                                   initializer=self.kernel_initializer,
                                   name='kernel2',
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
    if self.use_bias:
        self.bias2 = self.add_weight(shape=(channels,),
                                     initializer=self.bias_initializer,
                                     name='bias2',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)
    else:
        self.bias2 = None
    self.built = True
# Scraped web-page residue (not code): "评论列表" = "Comment list",
# "文章目录" = "Table of contents".