def _forward(self):
    """Build the normalised output expression and store it in ``self.output``.

    Steps performed, in order:

    1. Declare the scale (``self.gamma``) and shift (``self.beta``)
       parameters with shape ``(1, 1, n_output, 1, 1)``.
    2. Compute the mean and standard deviation of ``self.inpt`` over
       axes 0, 1, 3 and 4, leaving one statistic per entry of axis 2.
    3. Attach ``default_update`` expressions to ``self.running_mean`` and
       ``self.running_std`` that blend in the batch statistics with weight
       ``self.alpha`` while ``self.training`` is true and leave them
       unchanged otherwise.
    4. Normalise the input with the batch statistics (training) or the
       running statistics (otherwise), then apply ``gamma`` and ``beta``.

    NOTE(review): this looks like Theano symbolic-graph code (``ifelse``,
    ``dimshuffle``, ``default_update``) operating on a 5-D input whose
    axis 2 has length ``n_output`` -- confirm against the enclosing class.
    """
    epsilon = self.eps

    # Scale and shift parameters; the singleton axes let them broadcast
    # against the input along every axis except axis 2.
    broadcast_shape = (1, 1, self.n_output, 1, 1)
    self.gamma = self.declare(broadcast_shape)
    self.beta = self.declare(broadcast_shape)

    # Statistics are reduced over every axis except axis 2.
    reduced_axes = (0, 1, 3, 4)
    batch_mean = self.inpt.mean(axis=list(reduced_axes), keepdims=False)
    batch_std = self.inpt.std(axis=list(reduced_axes), keepdims=False)

    self._setup_running_metrics(self.n_output)

    def blended(running, batch):
        # Moving average: keep (1 - alpha) of the old value, take alpha
        # of the new batch statistic.
        return (1.0 - self.alpha) * running + self.alpha * batch

    # The running statistics only move while training; otherwise the
    # update expression is the variable itself (a no-op).
    self.running_mean.default_update = ifelse(
        self.training,
        blended(self.running_mean, batch_mean),
        self.running_mean,
    )
    self.running_std.default_update = ifelse(
        self.training,
        blended(self.running_std, batch_std),
        self.running_std,
    )

    # Adding a zero-weighted running statistic is optimised away, but it
    # makes the batch statistics depend on the running variables so that
    # their default updates are actually applied.
    # Reference: https://gist.github.com/f0k/f1a6bd3c8585c400c190#file-batch_norm-py-L86
    anchored_mean = batch_mean + 0 * self.running_mean
    anchored_std = batch_std + 0 * self.running_std

    # Batch statistics while training, running statistics otherwise.
    chosen_mean = ifelse(self.training, anchored_mean, self.running_mean)
    chosen_std = ifelse(self.training, anchored_std, self.running_std)

    # Re-insert broadcastable axes so the per-axis-2 statistics line up
    # with the layout of the input.
    pattern = ('x', 'x', 0, 'x', 'x')
    chosen_mean = chosen_mean.dimshuffle(*pattern)
    chosen_std = chosen_std.dimshuffle(*pattern)

    # epsilon is added to the standard deviation to guard the division.
    centred = self.inpt - chosen_mean
    normalised = centred / (chosen_std + epsilon)
    self.output = self.gamma * normalised + self.beta
# NOTE: stray web-page navigation text, not part of the code:
#   "Comment list" / "Article table of contents"