def _apply(self, X, noise=0):
    """Apply batch normalization to the input tensor ``X``.

    On the training path (``K.is_training()`` true) the input is normalized
    by the moments of the current batch, computed over ``self.axes``, and the
    stored running mean/variance are updated as an exponential moving
    average.  On the inference path the stored running statistics are used
    instead.

    NOTE(review): the ``noise`` argument is accepted but never read in this
    body — noise injection is driven by ``self.noise_level`` instead.
    Confirm whether ``noise`` is kept only for interface compatibility.

    Parameters
    ----------
    X : tf.Tensor
        Input tensor; its static rank (``ndims``) must be known.
    noise : int, optional
        Unused here (see note above).

    Returns
    -------
    tf.Tensor
        ``self.activation(gamma * (X - mean) * rsqrt(var + epsilon) + beta)``,
        with optional Gaussian noise applied before the ``beta`` shift.
    """
    ndim = X.get_shape().ndims
    # Batch moments over the normalization axes (only consumed on the
    # training branch of the tf.cond below).
    mean, var = tf.nn.moments(X, axes=self.axes)
    # Build a dimshuffle pattern that broadcasts the reduced parameters back
    # to the input rank: 'x' inserts a broadcastable axis at each normalized
    # axis; the remaining axes map, in order, to the parameter's own axes.
    param_axes = iter(range(ndim - len(self.axes)))
    pattern = ['x' if input_axis in self.axes else next(param_axes)
               for input_axis in range(ndim)]
    # Broadcast the learned shift/scale; fall back to the scalars 0 / 1 when
    # the corresponding parameter was not created.
    beta = 0 if self.beta_init is None else \
        K.dimshuffle(self.get('beta'), pattern)
    gamma = 1 if self.gamma_init is None else \
        K.dimshuffle(self.get('gamma'), pattern)

    # ====== if training: use local (batch) mean and var ====== #
    def training_fn():
        # Exponential-moving-average update of the stored statistics.
        # NOTE(review): with this formula ``alpha`` is the weight of the NEW
        # batch statistics (alpha=1 tracks each batch exactly) — confirm this
        # matches the documented meaning of ``alpha`` for this layer.
        running_mean = ((1 - self.alpha) * self.get('mean') +
                        self.alpha * mean)
        running_var = ((1 - self.alpha) * self.get('var') +
                       self.alpha * var)
        # The control dependency forces the assignments to run whenever this
        # branch executes; tf.identity then returns the *batch* moments to be
        # used for normalization during training.
        with tf.control_dependencies([
                tf.assign(self.get('mean'), running_mean),
                tf.assign(self.get('var'), running_var)]):
            return tf.identity(mean), tf.identity(var)

    # ====== if inference: use global (running) mean and var ====== #
    def infer_fn():
        return self.get('mean'), self.get('var')

    mean, var = tf.cond(K.is_training(), training_fn, infer_fn)
    inv_std = tf.rsqrt(var + self.epsilon)
    normalized = (X - K.dimshuffle(mean, pattern)) * \
        (gamma * K.dimshuffle(inv_std, pattern))
    # ====== applying noise if required ====== #
    if self.noise_level is not None:
        normalized = K.rand.apply_noise(normalized,
            level=self.noise_level, noise_dims=self.noise_dims,
            noise_type='gaussian')
    # Learned shift is added after the (optional) noise injection.
    normalized = normalized + beta
    # Activated output.
    return self.activation(normalized)