def get_output_for(self, input, deterministic=False, **kwargs):
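    # compute the statistics of the current mini-batch over the normalization axes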
input_mean = input.mean(self.axes)
input_var = input.var(self.axes)
# Decide whether to use the stored averages or mini-batch statistics
use_averages = kwargs.get('batch_norm_use_averages',
deterministic)
if use_averages:
mean = self.mean
var = self.var
else:
mean = input_mean
var = input_var
# Decide whether to update the stored averages
update_averages = kwargs.get('batch_norm_update_averages',
not deterministic)
if update_averages:
        # Trick: to update the stored statistics, we create memory-aliased
        # clones of them:
running_mean = theano.clone(self.mean, share_inputs=False)
running_var = theano.clone(self.var, share_inputs=False)
        # set a default update for them: an exponential moving average with rate alpha
running_mean.default_update = ((1 - self.alpha) * running_mean +
self.alpha * input_mean)
running_var.default_update = ((1 - self.alpha) * running_var +
self.alpha * input_var)
# and make sure they end up in the graph without participating in
# the computation (this way their default_update will be collected
# and applied, but the computation will be optimized away):
mean += 0 * running_mean
var += 0 * running_var
# prepare dimshuffle pattern inserting broadcastable axes as needed
param_axes = iter(range(self.beta.ndim))
pattern = ['x' if input_axis in self.axes
else next(param_axes)
for input_axis in range(input.ndim)]
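    # (e.g., a 4D convolutional input with axes=(0, 2, 3) yields the
    # pattern ['x', 0, 'x', 'x'], broadcasting the per-channel parameters
    # over the batch and spatial axes)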
# apply dimshuffle pattern to all parameters
beta = self.beta.dimshuffle(pattern)
gamma = self.gamma.dimshuffle(pattern)
mean = mean.dimshuffle(pattern)
std = T.sqrt(var + self.epsilon)
std = std.dimshuffle(pattern)
    # normalize: the op below computes (input - mean) * (gamma / std) + beta
normalized = T.nnet.batch_normalization(input, gamma=gamma, beta=beta,
mean=mean, std=std,
mode=self.mode)
return self.nonlinearity(normalized)
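
The `theano.clone` / `default_update` trick above is subtle enough to deserve a standalone illustration. Below is a minimal sketch (with a hypothetical shape and averaging rate) showing that a clone created with `share_inputs=False` aliases the original shared storage, so a `default_update` attached to the clone is collected by `theano.function` and writes back to the original variable:

import numpy as np
import theano
import theano.tensor as T

# a stored statistic, analogous to self.mean (hypothetical shape)
stat = theano.shared(np.zeros(4, dtype=theano.config.floatX))
batch_stat = T.vector('batch_stat')
alpha = 0.1  # assumed averaging rate

# memory-aliased clone: same storage as `stat`, but a distinct graph node
running = theano.clone(stat, share_inputs=False)
running.default_update = (1 - alpha) * running + alpha * batch_stat

# reference the clone so its default_update is collected, without
# changing the computed result (the 0 * term is optimized away)
out = batch_stat + 0 * running
f = theano.function([batch_stat], out)

f(np.ones(4, dtype=theano.config.floatX))
print(stat.get_value())  # nudged toward the batch statistic by alpha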
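
In practice this method is not called directly. Assuming the class is Lasagne's `BatchNormLayer` used through the `batch_norm` convenience wrapper, a minimal usage sketch (with a hypothetical input shape and layer sizes) looks like the following; the `deterministic` flag of `lasagne.layers.get_output` is what reaches the `deterministic` argument above:

import theano.tensor as T
import lasagne

X = T.tensor4('X')
net = lasagne.layers.InputLayer((None, 3, 32, 32), input_var=X)
net = lasagne.layers.batch_norm(
    lasagne.layers.Conv2DLayer(net, num_filters=16, filter_size=3))

# training graph: mini-batch statistics are used and the stored
# averages receive their default updates
train_out = lasagne.layers.get_output(net, deterministic=False)

# test graph: the stored averages are used and nothing is updated
test_out = lasagne.layers.get_output(net, deterministic=True)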