def build(self, input_shape):
    """Create this layer's ``gamma`` and ``beta`` weights and mark it built.

    Both weights are created with shape ``(1, ..., 1, C)``: one singleton
    axis for every input axis except the last, whose size ``C`` is taken
    from ``input_shape[-1]``. For a 4D input ``(B, H, W, C)`` this is
    ``(1, 1, 1, C)``, as before. For other ranks it is the matching
    rank, so the weights stay broadcast-compatible with the input.

    NOTE(review): this assumes channels-last data (the last axis holds the
    channels). ``gamma``/``beta`` are presumably a learned per-channel scale
    and offset applied in ``call``; confirm against the rest of the class.

    Args:
        input_shape: Shape of the layer input. Its length is the input rank,
            and its last entry is the per-channel weight size.

    Side effects:
        Sets ``self.input_spec``, ``self.gamma``, ``self.beta``,
        ``self.trainable_weights`` and ``self.built``.
    """
    self.input_spec = [InputSpec(shape=input_shape)]
    # Use singleton axes for everything except the channel axis so the
    # weights broadcast against inputs of any rank. This was previously
    # hard-coded to (1, 1, 1, C), which only supported 4D inputs.
    shape = (1,) * (len(input_shape) - 1) + (input_shape[-1],)
    self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
    self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
    self.trainable_weights = [self.gamma, self.beta]
    self.built = True
评论列表
文章目录