model.py source code


Project: streetview  Author: ydnaandy123
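```python
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.Print)


def _normalize(self, x, mean, mean_sq, message):
    """Normalize x with the given per-channel mean and mean of squares.

    Method excerpt: self.epsilon is set elsewhere in the class. Must be called
    inside a variable scope so "gamma_driver" and "beta" get unique names.
    """
    shape = x.get_shape().as_list()
    assert len(shape) == 4  # 4-D input with channels as the last axis
    # Learnable log-scale; gamma = exp(gamma_driver) is always positive.
    self.gamma_driver = tf.get_variable("gamma_driver", [shape[-1]],
                                        initializer=tf.random_normal_initializer(0., 0.02))
    gamma = tf.exp(self.gamma_driver)
    gamma = tf.reshape(gamma, [1, 1, 1, -1])  # broadcast over the first three axes
    # Learnable per-channel shift, initialized to zero.
    self.beta = tf.get_variable("beta", [shape[-1]],
                                initializer=tf.constant_initializer(0.))
    beta = tf.reshape(self.beta, [1, 1, 1, -1])
    assert self.epsilon is not None
    assert mean_sq is not None
    assert mean is not None
    # Var[x] = E[x^2] - (E[x])^2; epsilon guards against division by zero.
    std = tf.sqrt(self.epsilon + mean_sq - tf.square(mean))
    out = x - mean
    out = out / std
    # Optional debug output (the only use of `message`): per-channel mean and variance of out.
    # out = tf.Print(out, [tf.reduce_mean(out, [0, 1, 2]),
    #    tf.reduce_mean(tf.square(out - tf.reduce_mean(out, [0, 1, 2], keep_dims=True)), [0, 1, 2])],
    #    message, first_n=-1)
    out = out * gamma
    out = out + beta
    return out
```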
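To make the arithmetic of `_normalize` concrete, here is a minimal NumPy sketch of the same computation: out = gamma * (x - mean) / sqrt(epsilon + mean_sq - mean^2) + beta, with gamma = exp(gamma_driver). It assumes, as the `[1, 1, 1, -1]` reshapes and the commented-out `reduce_mean(out, [0, 1, 2])` debug line suggest, that `x` is a channels-last 4-D array and that `mean` and `mean_sq` are per-channel statistics taken over the first three axes. `normalize_np` is a hypothetical name, and the reference batch is random data, not anything from the streetview project.

```python
import numpy as np


def normalize_np(x, mean, mean_sq, gamma_driver, beta, epsilon=1e-5):
    """Hypothetical NumPy re-statement of the _normalize arithmetic.

    x: array of shape (N, H, W, C); mean, mean_sq, gamma_driver, beta: shape (C,).
    """
    assert x.ndim == 4
    gamma = np.exp(gamma_driver).reshape(1, 1, 1, -1)  # positive per-channel scale
    beta = beta.reshape(1, 1, 1, -1)                   # per-channel shift
    # Var[x] = E[x^2] - (E[x])^2, plus epsilon for numerical stability.
    std = np.sqrt(epsilon + mean_sq - np.square(mean))
    return (x - mean) / std * gamma + beta


# Usage: per-channel statistics over axes (0, 1, 2) of a random reference batch.
rng = np.random.default_rng(0)
ref = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 4, 3))
mean = ref.mean(axis=(0, 1, 2))
mean_sq = np.square(ref).mean(axis=(0, 1, 2))
C = ref.shape[-1]
out = normalize_np(ref, mean, mean_sq, gamma_driver=np.zeros(C), beta=np.zeros(C))
print(out.mean(axis=(0, 1, 2)))  # each channel close to 0
print(out.std(axis=(0, 1, 2)))   # each channel just under 1 because of epsilon
```

With `gamma_driver` at zero the scale is exp(0) = 1, so normalizing the reference batch by its own statistics yields channels with mean near 0 and standard deviation near 1. Parameterizing the scale as exp(gamma_driver), as the original does, keeps it strictly positive whatever value the optimizer assigns to the driver.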