batch_normalization.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def __init__(self, comm, eps=2e-5, mean=None, var=None, decay=0.9):
        """Initialize the multi-node batch normalization function.

        Emits chainer's "experimental feature" notice, stores the given
        arguments on the instance, and keeps references to the lazily
        imported MPI-related modules.

        Args:
            comm: Communicator object, stored as ``self.comm``.
            eps: Epsilon value, stored as ``self.eps``. Must not be below
                1e-5 when cuDNN is in use.
            mean: Stored as ``self.running_mean``.
            var: Stored as ``self.running_var``.
            decay: Stored as ``self.decay``.

        Raises:
            RuntimeError: If ``chainer.should_use_cudnn('>=auto')`` is true
                and ``eps`` is less than 1e-5.
        """
        chainer.utils.experimental(
            'chainermn.functions.MultiNodeBatchNormalizationFunction')

        self.comm = comm
        self.running_mean, self.running_var = mean, var
        self.eps = eps

        # Note: cuDNN v5 puts a lower bound on eps; a smaller value results
        # in an error. The minimum allowable value can be verified against
        # CUDNN_BN_MIN_EPSILON in cudnn.h.
        if chainer.should_use_cudnn('>=auto') and eps < 1e-5:
            raise RuntimeError(
                'cuDNN does not allow an eps value less than 1e-5.')

        self.mean_cache = None
        self.decay = decay

        # MPI4py (and modules that import MPI4py) must be imported lazily,
        # so the imports happen here rather than at module level.
        import chainermn.communicators._memory_utility as _memory_utility
        from mpi4py import MPI as _mpi
        self.memory_utility_module = _memory_utility
        self.mpi4py_module = _mpi
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号