def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=numpy.float32,
             use_gamma=True, use_beta=True,
             initial_gamma=None, initial_beta=None):
    """Set up the running statistics and the optional scale/shift parameters.

    Args:
        size: Shape passed to ``numpy.zeros`` for the ``avg_mean`` and
            ``avg_var`` buffers and to ``variable.Parameter`` for
            ``gamma`` and ``beta``.
        comm: Stored unchanged as ``self.comm``.
            NOTE(review): presumably a ChainerMN communicator -- confirm
            against the caller.
        decay: Stored unchanged as ``self.decay``.
        eps: Stored unchanged as ``self.eps``.
        dtype: dtype of the two statistic buffers; also assigned to the
            ``dtype`` attribute of the gamma/beta initializers.
        use_gamma: When true, a ``gamma`` parameter is created.
        use_beta: When true, a ``beta`` parameter is created.
        initial_gamma: Handed to ``initializers._get_initializer``;
            ``1`` is used when this is ``None``.
        initial_beta: Handed to ``initializers._get_initializer``;
            ``0`` is used when this is ``None``.

    Raises:
        RuntimeError: If ``chainer.__version__`` starts with ``'1.'``.
    """
    chainer.utils.experimental(
        'chainermn.links.MultiNodeBatchNormalization')

    running_on_chainer_v1 = chainer.__version__.startswith('1.')
    if running_on_chainer_v1:
        raise RuntimeError(
            'MultiNodeBatchNormalization works only with '
            'chainer >= 2.0.0.')

    super(MultiNodeBatchNormalization, self).__init__()

    # Plain attributes; none of these are registered with the link.
    self.comm = comm
    self.decay = decay
    self.eps = eps

    # Values registered as persistent, in the order avg_mean, avg_var, N.
    self.avg_mean = numpy.zeros(size, dtype=dtype)
    self.avg_var = numpy.zeros(size, dtype=dtype)
    self.N = 0
    for persistent_name in ('avg_mean', 'avg_var', 'N'):
        self.register_persistent(persistent_name)

    def _typed_initializer(spec, fallback):
        # Resolve ``spec`` (or ``fallback`` when it is None) to an
        # initializer object and set its ``dtype`` attribute to ``dtype``.
        if spec is None:
            spec = fallback
        resolved = initializers._get_initializer(spec)
        resolved.dtype = dtype
        return resolved

    with self.init_scope():
        if use_gamma:
            gamma_init = _typed_initializer(initial_gamma, 1)
            self.gamma = variable.Parameter(gamma_init, size)
        if use_beta:
            beta_init = _typed_initializer(initial_beta, 0)
            self.beta = variable.Parameter(beta_init, size)
# NOTE: web-page residue, not part of the module: "Comment list" / "Table of contents".