def build(self, input_shape):
    """Allocate the normalization variables for an input of ``input_shape``.

    Creates a learnable scale (``gamma``) and offset (``beta``) plus running
    mean/std buffers. Each is created with shape ``(input_shape[self.axis],)``.
    If ``self.initial_weights`` is not None, it is loaded via ``set_weights``.
    The attribute is then deleted from the instance.

    Args:
        input_shape: Shape of the incoming tensor. Indexed with ``self.axis``
            to obtain the length of every per-feature variable.
    """
    # Record the expected input shape on the layer.
    self.input_spec = [InputSpec(shape=input_shape)]

    # Each variable holds one value per entry along the normalized axis.
    param_shape = (input_shape[self.axis],)
    layer_name = self.name

    # Trainable affine parameters from the configured initializers.
    # Creation order (gamma first, then beta) is preserved.
    self.gamma = self.gamma_init(param_shape,
                                 name='{}_gamma'.format(layer_name))
    self.beta = self.beta_init(param_shape,
                               name='{}_beta'.format(layer_name))
    self.trainable_weights = [self.gamma, self.beta]

    # Running statistics: K.zeros for the mean and K.ones for the std.
    # Both are registered as non-trainable weights.
    self.running_mean = K.zeros(param_shape,
                                name='{}_running_mean'.format(layer_name))
    self.running_std = K.ones(param_shape,
                              name='{}_running_std'.format(layer_name))
    self.non_trainable_weights = [self.running_mean, self.running_std]

    # Apply user-provided starting weights, then remove the stored attribute.
    preset = self.initial_weights
    if preset is not None:
        self.set_weights(preset)
        del self.initial_weights

    self.built = True
    # NOTE(review): `called_with` is reset here; its consumer is outside this view.
    self.called_with = None
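# Scraped page residue (translated from Chinese):
# "KerasBatchNormalization.py source file" | language: python
# Read 20 | Favorites 0 | Likes 0 | Comments 0
# Comment list | Article table of contents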