def get_output_for(self, input, deterministic=False, **kwargs):
    """Build the batch-normalized output expression for ``input``.

    The input is centered and scaled using either the statistics of the
    input itself or the averages stored on the layer (``self.mean`` and
    ``self.inv_std``), then scaled by ``self.gamma`` and shifted by
    ``self.beta`` when those are not ``None``.

    Parameters
    ----------
    input
        Expression to normalize. It must provide ``mean``, ``var`` and
        ``ndim`` (presumably a Theano tensor -- confirm against callers).
    deterministic : bool
        Supplies the defaults for both keyword switches below: stored
        averages are used when it is true, and the stored averages are
        updated when it is false.
    **kwargs
        ``batch_norm_use_averages`` overrides whether the stored averages
        are used for normalization; ``batch_norm_update_averages``
        overrides whether the stored averages get updated.

    Returns
    -------
    The expression ``(input - mean) * (gamma * inv_std) + beta``.
    """
    # Statistics of the current input, reduced over the normalized axes.
    batch_mean = input.mean(self.axes)
    batch_inv_std = T.inv(T.sqrt(input.var(self.axes) + self.epsilon))

    # Both switches default from ``deterministic`` but can be overridden
    # independently through the keyword arguments.
    use_stored = kwargs.get('batch_norm_use_averages', deterministic)
    track_stats = kwargs.get('batch_norm_update_averages',
                             not deterministic)

    # Pick the statistics used for the actual normalization.
    if use_stored:
        mean, inv_std = self.mean, self.inv_std
    else:
        mean, inv_std = batch_mean, batch_inv_std

    if track_stats:
        # Clone the stored statistics; per the original author's note the
        # clones are memory-aliased to the stored values.
        mean_tracker = theano.clone(self.mean, share_inputs=False)
        inv_std_tracker = theano.clone(self.inv_std, share_inputs=False)
        # Attach an exponential-moving-average step as the default update
        # of each clone.
        keep = 1 - self.alpha
        mean_tracker.default_update = (keep * mean_tracker +
                                       self.alpha * batch_mean)
        inv_std_tracker.default_update = (keep * inv_std_tracker +
                                          self.alpha * batch_inv_std)
        # Adding the clones with a zero weight places them in the graph so
        # their default updates are collected, without changing the
        # numerical result.
        mean += 0 * mean_tracker
        inv_std += 0 * inv_std_tracker

    # Dimshuffle pattern: 'x' (broadcastable) for every normalized axis,
    # consecutive parameter axes for the remaining ones.
    free_axes = iter(range(input.ndim - len(self.axes)))
    pattern = []
    for axis in range(input.ndim):
        pattern.append('x' if axis in self.axes else next(free_axes))

    # Bring the parameters and statistics to the input's dimensionality;
    # missing parameters fall back to the neutral shift / scale.
    if self.beta is None:
        beta = 0
    else:
        beta = self.beta.dimshuffle(pattern)
    if self.gamma is None:
        gamma = 1
    else:
        gamma = self.gamma.dimshuffle(pattern)
    mean = mean.dimshuffle(pattern)
    inv_std = inv_std.dimshuffle(pattern)

    # Normalize: center, scale, shift.
    centered = input - mean
    scale = gamma * inv_std
    return centered * scale + beta
评论列表
文章目录