link.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:instance_normalization_chainer 作者: crcrpar 项目源码 文件源码
def __call__(self, x, gamma_=None, beta_=None):
        """Apply instance normalization to ``x``.

        The scale is ``self.gamma`` when the link has that attribute,
        otherwise ``gamma_`` when it is given, otherwise an all-ones
        variable shaped like ``self.avg_mean`` with the dtype of ``x``,
        created under the device context of ``self._device_id``. The shift
        is resolved the same way from ``self.beta`` / ``beta_`` / zeros.

        Args:
            x: Input to normalize.
            gamma_: Scale, consulted only when the link has no ``gamma``
                attribute.
            beta_: Shift, consulted only when the link has no ``beta``
                attribute.

        Returns:
            The output of ``InstanceNormalizationFunction`` when
            ``configuration.config.train`` is truthy or ``self.valid_test``
            is falsy; otherwise the output of
            ``fixed_instance_normalization`` computed with the stored
            ``self.avg_mean`` / ``self.avg_var``.
        """
        # Resolve the scale: own attribute > caller-supplied > ones.
        if hasattr(self, 'gamma'):
            scale = self.gamma
        elif gamma_ is None:
            with cuda.get_device_from_id(self._device_id):
                ones = self.xp.ones(self.avg_mean.shape, dtype=x.dtype)
                scale = variable.Variable(ones)
        else:
            scale = gamma_

        # Resolve the shift: own attribute > caller-supplied > zeros.
        if hasattr(self, 'beta'):
            shift = self.beta
        elif beta_ is None:
            with cuda.get_device_from_id(self._device_id):
                zeros = self.xp.zeros(self.avg_mean.shape, dtype=x.dtype)
                shift = variable.Variable(zeros)
        else:
            shift = beta_

        momentum = self.decay

        if configuration.config.train or not self.valid_test:
            # Normalize through the function object, then keep the running
            # statistics it exposes for later use by the fixed path.
            normalizer = InstanceNormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, momentum)
            out = normalizer(x, scale, shift)
            self.avg_mean = normalizer.running_mean
            self.avg_var = normalizer.running_var
            return out

        # Not training and valid_test is set: use the stored statistics.
        stored_mean = variable.Variable(self.avg_mean)
        stored_var = variable.Variable(self.avg_var)
        return fixed_instance_normalization(
            x, scale, shift, stored_mean, stored_var, self.eps)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号