def __init__(self, comm, eps=2e-5, mean=None, var=None, decay=0.9):
    """Configure a multi-node batch normalization function.

    Marks the function as experimental, stores the hyperparameters and
    running statistics on the instance, validates ``eps`` against cuDNN's
    minimum when cuDNN would be used, and lazily imports the MPI-dependent
    helper modules.

    Args:
        comm: Communicator object, kept as ``self.comm`` (presumably a
            ChainerMN communicator -- confirm against callers).
        eps (float): Epsilon value, kept as ``self.eps``. When
            ``chainer.should_use_cudnn('>=auto')`` is true it must not be
            smaller than ``1e-5``.
        mean: Running mean, kept as ``self.running_mean``.
        var: Running variance, kept as ``self.running_var``.
        decay (float): Decay rate, kept as ``self.decay``.

    Raises:
        RuntimeError: If cuDNN would be used and ``eps < 1e-5``.
    """
    chainer.utils.experimental(
        'chainermn.functions.MultiNodeBatchNormalizationFunction')

    self.comm = comm
    self.running_mean, self.running_var = mean, var

    # cuDNN (v5) rejects an epsilon below CUDNN_BN_MIN_EPSILON (see cudnn.h),
    # so reject such values up front with a readable error instead.
    self.eps = eps
    if chainer.should_use_cudnn('>=auto') and eps < 1e-5:
        raise RuntimeError(
            'cuDNN does not allow an eps value less than 1e-5.')

    self.mean_cache = None
    self.decay = decay

    # Importing mpi4py (and modules that import it) is deliberately deferred
    # to construction time rather than done at module import time.
    import chainermn.communicators._memory_utility as _mem_util
    from mpi4py import MPI as _mpi
    self.memory_utility_module = _mem_util
    self.mpi4py_module = _mpi
评论列表
文章目录