def __call__(self, x, gamma_=None, beta_=None):
    """Apply instance normalization to ``x``.

    Scale and shift are resolved in priority order:

    1. the link's own ``gamma`` / ``beta`` attribute, if it has one;
    2. otherwise the caller-supplied ``gamma_`` / ``beta_``;
    3. otherwise a constant ``Variable`` of ones (scale) or zeros (shift).
       It has the shape of ``self.avg_mean`` and the dtype of ``x``, and is
       allocated under ``cuda.get_device_from_id(self._device_id)``.

    Which normalization path is taken:

    - Not training (``configuration.config.train`` is falsy) and
      ``self.valid_test`` is truthy: the stored ``avg_mean`` / ``avg_var``
      are passed to ``fixed_instance_normalization``.
    - Any other case: ``InstanceNormalizationFunction`` is built from
      ``eps``, ``avg_mean``, ``avg_var`` and ``decay`` and applied to
      ``x``. Afterwards ``self.avg_mean`` / ``self.avg_var`` are rebound to
      the function's ``running_mean`` / ``running_var``.

    NOTE(review): the second path also runs in evaluation mode when
    ``valid_test`` is falsy, so the running statistics get overwritten
    during inference too -- confirm this is intended.

    Args:
        x: Input to normalize.
        gamma_: Fallback scale, used only when the link has no ``gamma``.
        beta_: Fallback shift, used only when the link has no ``beta``.

    Returns:
        The normalized output produced by whichever path was taken.
    """
    def _pick(attr_name, override, fill_name):
        # Prefer the link's own parameter, then the explicit override, and
        # only as a last resort allocate a constant on the link's device.
        if hasattr(self, attr_name):
            return getattr(self, attr_name)
        if override is not None:
            return override
        with cuda.get_device_from_id(self._device_id):
            fill = getattr(self.xp, fill_name)
            return variable.Variable(fill(self.avg_mean.shape, dtype=x.dtype))

    gamma = _pick('gamma', gamma_, 'ones')
    beta = _pick('beta', beta_, 'zeros')

    # Read before branching so a missing ``decay`` attribute fails the
    # same way on every path.
    decay = self.decay
    use_stored_stats = not configuration.config.train and self.valid_test

    if not use_stored_stats:
        norm_fn = InstanceNormalizationFunction(
            self.eps, self.avg_mean, self.avg_var, decay)
        y = norm_fn(x, gamma, beta)
        # Keep whatever statistics the forward pass left on the function
        # object; presumably updated with ``decay`` -- confirm in
        # InstanceNormalizationFunction.
        self.avg_mean = norm_fn.running_mean
        self.avg_var = norm_fn.running_var
        return y

    stored_mean = variable.Variable(self.avg_mean)
    stored_var = variable.Variable(self.avg_var)
    return fixed_instance_normalization(
        x, gamma, beta, stored_mean, stored_var, self.eps)
评论列表
文章目录