def call(self, x, mask=None):
    """Apply batch normalization to ``x`` according to ``self.mode``.

    Supported modes:

    * ``0`` -- normalize over every axis except ``self.axis``.
      * Training: uses the statistics of the current batch and schedules
        moving-average updates of ``running_mean`` / ``running_std`` through
        ``self.updates``.
      * Test: uses the running statistics instead. ``K.in_train_phase``
        selects between the two.
    * ``1`` -- sample-wise normalization over the last axis.
    * ``2`` -- like mode 0, but it always uses the current batch statistics
      and never updates the running statistics.

    Args:
        x: Input tensor.
        mask: Part of the layer ``call`` signature; not used by this method.

    Returns:
        The normalized tensor, scaled by ``gamma`` and shifted by ``beta``.

    Raises:
        ValueError: If ``self.mode`` is not 0, 1 or 2.
    """
    if self.mode in (0, 2):
        assert self.built, 'Layer must be built before being called'
        # NOTE(review): assumes input_spec carries a static shape with a known
        # size along self.axis (it is used in a reshape target) -- confirm.
        input_shape = self.input_spec[0].shape

        # Reduce over every axis except self.axis.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        # All-ones shape except along self.axis. Per-axis vectors (mean, std,
        # gamma, beta) are reshaped to it so they broadcast against x.
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Statistics of the current batch.
        mean = K.mean(x, axis=reduction_axes)
        broadcast_mean = K.reshape(mean, broadcast_shape)
        # NOTE: epsilon is added under the square root here AND again to the
        # denominator below. This double epsilon is kept deliberately so that
        # outputs and stored running_std values stay unchanged.
        std = K.sqrt(K.mean(K.square(x - broadcast_mean) + self.epsilon,
                            axis=reduction_axes))
        broadcast_std = K.reshape(std, broadcast_shape)
        x_normed = (x - broadcast_mean) / (broadcast_std + self.epsilon)

        if self.mode == 0:
            self.called_with = x
            # Moving-average updates of the running statistics. Only mode 0
            # consumes these, so they are not built for mode 2.
            mean_update = (self.momentum * self.running_mean
                           + (1 - self.momentum) * mean)
            std_update = (self.momentum * self.running_std
                          + (1 - self.momentum) * std)
            self.updates = [(self.running_mean, mean_update),
                            (self.running_std, std_update)]
            # Test-time path: normalize with the running statistics.
            broadcast_running_mean = K.reshape(self.running_mean,
                                               broadcast_shape)
            broadcast_running_std = K.reshape(self.running_std,
                                              broadcast_shape)
            x_normed_running = ((x - broadcast_running_mean)
                                / (broadcast_running_std + self.epsilon))
            # Batch statistics in the training phase, running ones otherwise.
            x_normed = K.in_train_phase(x_normed, x_normed_running)

        return (K.reshape(self.gamma, broadcast_shape) * x_normed
                + K.reshape(self.beta, broadcast_shape))

    if self.mode == 1:
        # Sample-wise normalization over the last axis. keepdims=True keeps
        # the per-sample statistics broadcastable against x.
        m = K.mean(x, axis=-1, keepdims=True)
        std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
        x_normed = (x - m) / (std + self.epsilon)
        # NOTE(review): gamma/beta are not reshaped here. This assumes they
        # already broadcast against the last axis of x -- confirm.
        return self.gamma * x_normed + self.beta

    # Previously an unsupported mode fell through to an UnboundLocalError.
    raise ValueError('Invalid BatchNormalization mode: %r '
                     '(expected 0, 1 or 2).' % (self.mode,))
KerasBatchNormalization.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录