custom_ops.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:how_to_convert_text_to_images 作者: llSourcell 项目源码 文件源码
def __call__(self, input_layer, epsilon=1e-5, decay=0.9, name="batch_norm",
             in_dim=None, phase=Phase.train):
    """Batch-normalize ``input_layer`` and return it wrapped around the result.

    Four per-channel variables are created inside the variable scope
    ``name`` and stored on ``self``: the running ``mean`` and ``variance``
    (both created with ``train=False``) plus the ``gamma`` and ``beta``
    parameters.

    Args:
        input_layer: Layer wrapper exposing ``shape``, ``tensor`` and
            ``with_tensor``. Batch moments are reduced over axes
            ``[0, 1, 2]``, so this presumably expects a 4-D tensor with
            channels last -- TODO confirm against callers.
        epsilon: Variance epsilon handed to the normalization op.
        decay: Decay factor for the moving-average updates of the running
            mean and variance.
        name: Name of the variable scope the variables are created in.
        in_dim: Length of the per-channel variables. When falsy, the last
            entry of ``input_layer.shape`` is used instead.
        phase: With ``Phase.train`` the batch moments are used for
            normalization and the running statistics are updated; any
            other value normalizes with the stored running statistics.

    Returns:
        The result of ``input_layer.with_tensor`` for the normalized
        tensor, with ``parameters=self.vars``.
    """
    input_shape = input_layer.shape
    channels = in_dim if in_dim else input_shape[-1]

    def _normalize(tensor, mu, sigma_sq):
        # Single place for the normalization op shared by both phases.
        return tf.nn.batch_norm_with_global_normalization(
            tensor, mu, sigma_sq, self.beta, self.gamma, epsilon,
            scale_after_normalization=True)

    with tf.variable_scope(name):
        # Running statistics, updated only through the moving averages below.
        self.mean = self.variable(
            'mean', [channels], init=tf.constant_initializer(0.), train=False)
        self.variance = self.variable(
            'variance', [channels], init=tf.constant_initializer(1.0), train=False)

        # Scale and offset applied after normalization.
        self.gamma = self.variable(
            "gamma", [channels], init=tf.random_normal_initializer(1., 0.02))
        self.beta = self.variable(
            "beta", [channels], init=tf.constant_initializer(0.))

        inputs = input_layer.tensor
        if phase == Phase.train:
            batch_mean, batch_variance = tf.nn.moments(inputs, [0, 1, 2])
            for moment in (batch_mean, batch_variance):
                moment.set_shape((channels,))

            moving_updates = [
                moving_averages.assign_moving_average(self.mean, batch_mean, decay),
                moving_averages.assign_moving_average(
                    self.variance, batch_variance, decay),
            ]
            # Building the op under these dependencies makes the running
            # statistics update whenever the normalized output is evaluated.
            with tf.control_dependencies(moving_updates):
                normalized = _normalize(inputs, batch_mean, batch_variance)
        else:
            normalized = _normalize(inputs, self.mean, self.variance)

        return input_layer.with_tensor(normalized, parameters=self.vars)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号