batch_normalization.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def __init__(self, size, comm, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None):
        """Set up the link's running statistics and optional parameters.

        Args:
            size: Shape passed to ``numpy.zeros`` for ``avg_mean`` and
                ``avg_var`` and to ``variable.Parameter`` for ``gamma`` and
                ``beta``.
            comm: Stored as ``self.comm``. Presumably a ChainerMN
                communicator consumed by the forward pass (not shown here);
                confirm against the caller.
            decay: Stored as ``self.decay``. Presumably the running-average
                decay rate; confirm against the forward pass.
            eps: Stored as ``self.eps``. Presumably the epsilon added to the
                variance; confirm against the forward pass.
            dtype: dtype of ``avg_mean`` / ``avg_var``; also assigned to the
                ``dtype`` attribute of both parameter initializers.
            use_gamma: When true, register the parameter ``self.gamma``.
            use_beta: When true, register the parameter ``self.beta``.
            initial_gamma: Value handed to ``initializers._get_initializer``
                for ``gamma``; ``None`` is replaced by ``1``.
            initial_beta: Value handed to ``initializers._get_initializer``
                for ``beta``; ``None`` is replaced by ``0``.

        Raises:
            RuntimeError: If ``chainer.__version__`` starts with ``'1.'``.
        """
        chainer.utils.experimental(
            'chainermn.links.MultiNodeBatchNormalization')

        # Only Chainer >= 2.0.0 is supported (see the error message below).
        running_chainer_v1 = chainer.__version__.startswith('1.')
        if running_chainer_v1:
            raise RuntimeError(
                'MultiNodeBatchNormalization works only with '
                'chainer >= 2.0.0.')

        # Initialize the base link first: register_persistent and
        # init_scope below depend on the state it sets up.
        super(MultiNodeBatchNormalization, self).__init__()
        self.comm = comm

        # Running mean/variance buffers plus a counter, each registered as a
        # persistent value right after it is assigned.
        for attr_name, start_value in (
                ('avg_mean', numpy.zeros(size, dtype=dtype)),
                ('avg_var', numpy.zeros(size, dtype=dtype)),
                ('N', 0)):
            setattr(self, attr_name, start_value)
            self.register_persistent(attr_name)

        self.decay = decay
        self.eps = eps

        def make_param(init_value, fallback):
            # ``None`` selects the constant fallback (1 for gamma, 0 for beta).
            chosen = fallback if init_value is None else init_value
            param_init = initializers._get_initializer(chosen)
            # NOTE(review): if _get_initializer returns a caller-supplied
            # initializer object unchanged, this overwrites its dtype in
            # place; confirm that this is acceptable.
            param_init.dtype = dtype
            return variable.Parameter(param_init, size)

        # Parameters must be assigned inside init_scope to be registered.
        with self.init_scope():
            if use_gamma:
                self.gamma = make_param(initial_gamma, 1)
            if use_beta:
                self.beta = make_param(initial_beta, 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号