layers.py source code

python

Project: nn_playground  Author: DingKe
    def build(self, input_shape):
        # Only 4D inputs (batch, two spatial axes, and a channel axis) are accepted.
        assert len(input_shape) == 4
        self.input_spec = InputSpec(shape=input_shape)

        # Locate the channel axis for the configured data format.
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = 3
        channels = input_shape[channel_axis]

        # First projection: reduce `channels` to `channels // self.ratio` units.
        self.kernel1 = self.add_weight(shape=(channels, channels // self.ratio),
                                      initializer=self.kernel_initializer,
                                      name='kernel1',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias1 = self.add_weight(shape=(channels // self.ratio,),
                                         initializer=self.bias_initializer,
                                         name='bias1',
                                         regularizer=self.bias_regularizer,
                                         constraint=self.bias_constraint)
        else:
            self.bias1 = None

        # Second projection: expand back to `channels` units.
        self.kernel2 = self.add_weight(shape=(channels // self.ratio, channels),
                                      initializer=self.kernel_initializer,
                                      name='kernel2',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias2 = self.add_weight(shape=(channels,),
                                         initializer=self.bias_initializer,
                                         name='bias2',
                                         regularizer=self.bias_regularizer,
                                         constraint=self.bias_constraint)
        else:
            self.bias2 = None

        self.built = True
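
The shape arithmetic in `build` can be checked without Keras. The following is a minimal sketch in plain Python that mirrors the channel-axis lookup and the `channels // ratio` bottleneck above. The helper names `bottleneck_weight_shapes` and `count_parameters` are hypothetical and not part of nn_playground, and the explicit check for a zero-sized bottleneck is an added assumption (the original `build` has no such guard).

```python
def bottleneck_weight_shapes(input_shape, ratio,
                             data_format="channels_last", use_bias=True):
    """Return the weight shapes that the build() logic above would request.

    Hypothetical helper for illustration only; not part of nn_playground.
    """
    if len(input_shape) != 4:
        raise ValueError("expected a 4D input shape")

    # Same axis selection as build(): axis 1 for channels_first, else axis 3.
    channel_axis = 1 if data_format == "channels_first" else 3
    channels = input_shape[channel_axis]

    reduced = channels // ratio
    if reduced == 0:
        # Added guard (assumption): floor division gives 0 when ratio > channels,
        # which would produce weights with a zero-sized dimension.
        raise ValueError("ratio must not exceed the number of channels")

    shapes = {"kernel1": (channels, reduced), "kernel2": (reduced, channels)}
    if use_bias:
        shapes["bias1"] = (reduced,)
        shapes["bias2"] = (channels,)
    return shapes


def count_parameters(shapes):
    """Sum the number of scalar entries over all weight shapes."""
    total = 0
    for shape in shapes.values():
        size = 1
        for dim in shape:
            size *= dim
        total += size
    return total


shapes = bottleneck_weight_shapes((None, 32, 32, 64), ratio=16)
print(shapes)
print(count_parameters(shapes))  # 64*4 + 4 + 4*64 + 64 = 580
```

For a `channels_last` input with 64 channels and `ratio=16`, the two kernels have shapes `(64, 4)` and `(4, 64)`, so the layer adds 580 trainable parameters when biases are enabled.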