ops.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:hyperchamber 作者: 255BITS 项目源码 文件源码
def __call__(self, x):
        """Batch-normalize ``x`` and return the normalized tensor.

        Under ``tf.variable_scope(self.name)`` this fetches two per-channel
        variables via ``tf.get_variable``:

        * ``gamma``, initialized from a normal distribution with mean 1.0 and
          stddev 0.02;
        * ``beta``, initialized to 0.

        It then takes the moments of ``x`` over axes 0, 1 and 2 and registers
        them with ``self.ema``. That object is presumably a
        ``tf.train.ExponentialMovingAverage`` — confirm.

        Attributes read from ``self``: ``in_dim``, ``name``, ``ema``,
        ``train``, ``epsilon``.
        Attributes written on ``self``: ``gamma``, ``beta``, ``mean``,
        ``variance``, ``ema_apply_op``.

        Args:
            x: input tensor. Moments are reduced over axes [0, 1, 2], and
                the results are pinned to shape ``(channels,)``. This
                presumably expects a rank-4, channels-last tensor (e.g.
                NHWC) — TODO confirm.

        Returns:
            The result of ``tf.nn.batch_norm_with_global_normalization`` with
            ``scale_after_normalization=True``. When ``self.train`` is truthy,
            the current batch moments are used. Otherwise the EMA averages of
            those moments (``self.ema.average(...)``) are used.
        """
        static_shape = x.get_shape()
        # Channel count: an explicit ``in_dim`` wins; otherwise use the static
        # size of the last axis.
        channels = self.in_dim or static_shape[-1]

        with tf.variable_scope(self.name):
            # Learnable per-channel scale (centred on 1.0) and offset (zero).
            self.gamma = tf.get_variable(
                "gamma", [channels],
                initializer=tf.random_normal_initializer(1., 0.02))
            self.beta = tf.get_variable(
                "beta", [channels],
                initializer=tf.constant_initializer(0.))

            # Batch statistics, pinned to a static per-channel shape.
            self.mean, self.variance = tf.nn.moments(x, [0, 1, 2])
            for stat in (self.mean, self.variance):
                stat.set_shape((channels,))

            # NOTE(review): this update op is only stored, never executed or
            # attached via tf.control_dependencies here. The moving averages
            # are therefore refreshed only if a caller runs
            # ``self.ema_apply_op`` explicitly — confirm that is intended.
            self.ema_apply_op = self.ema.apply([self.mean, self.variance])

            if self.train:
                # Training: normalize with the statistics of this batch.
                mean, variance = self.mean, self.variance
            else:
                # Inference: substitute the EMA shadow values of the moments.
                mean = self.ema.average(self.mean)
                variance = self.ema.average(self.variance)

            return tf.nn.batch_norm_with_global_normalization(
                x, mean, variance, self.beta, self.gamma, self.epsilon,
                scale_after_normalization=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号