def __call__(self, x, finetune=False):
    """Normalize ``x`` with batch statistics or the stored averages.

    When ``chainer.configuration.config.train`` is true, the input is
    passed through ``MultiNodeBatchNormalizationFunction`` and the
    ``avg_mean`` / ``avg_var`` arrays held by this object are overwritten
    in place with the function's ``running_mean`` / ``running_var``.
    Otherwise the stored ``avg_mean`` / ``avg_var`` are wrapped in
    variables and handed to ``fixed_batch_normalization``.

    Args:
        x: Input to normalize. Its ``dtype`` is used for any fallback
            scale/shift parameter created here.
        finetune (bool): Only consulted in training mode. When true,
            ``self.N`` is incremented and ``1 - 1 / N`` is used as the
            decay instead of ``self.decay``.

    Returns:
        The output of whichever normalization function was applied.
    """
    # No ``gamma`` attribute: substitute an all-ones scale shaped like
    # ``avg_mean``, allocated under this object's device.
    if not hasattr(self, 'gamma'):
        with cuda.get_device_from_id(self._device_id):
            scale = variable.Variable(
                self.xp.ones(self.avg_mean.shape, dtype=x.dtype))
    else:
        scale = self.gamma

    # No ``beta`` attribute: substitute an all-zeros shift, same way.
    if not hasattr(self, 'beta'):
        with cuda.get_device_from_id(self._device_id):
            shift = variable.Variable(
                self.xp.zeros(self.avg_mean.shape, dtype=x.dtype))
    else:
        shift = self.beta

    if not chainer.configuration.config.train:
        # Use running average statistics or fine-tuned statistics.
        stored_mean = variable.Variable(self.avg_mean)
        stored_var = variable.Variable(self.avg_var)
        return batch_normalization.fixed_batch_normalization(
            x, scale, shift, stored_mean, stored_var, self.eps)

    if finetune:
        # Each fine-tuning call bumps the counter before deriving the rate.
        self.N += 1
        rate = 1. - 1. / self.N
    else:
        rate = self.decay

    bn_func = MultiNodeBatchNormalizationFunction(
        self.comm, self.eps, self.avg_mean, self.avg_var, rate)
    output = bn_func(x, scale, shift)

    # Slice assignment keeps the existing arrays and only replaces their
    # contents with the statistics the function produced.
    self.avg_mean[:] = bn_func.running_mean
    self.avg_var[:] = bn_func.running_var
    return output
# Comment list
# Table of contents