pot.py source code

python

Project: adagan  Author: tolstikhin
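import tensorflow as tf  # TensorFlow 1.x graph-mode API (uses tf.Print)


def discriminator_lks_test(self, opts, input_):
        r"""Deterministic discriminator using the Kernel Stein Discrepancy test.
        Refer to the quadratic-time test discussed in https://arxiv.org/pdf/1705.07673.pdf

        The statistic basically reads:
            \[
                \frac{1}{n^2 - n}\sum_{i \neq j} \left(
                    \frac{\langle x_i, x_j \rangle}{\sigma_p^4}
                    + \frac{d}{\sigma_k^2}
                    - \|x_i - x_j\|^2\left(\frac{1}{\sigma_p^2\sigma_k^2} + \frac{1}{\sigma_k^4}\right)
                \right)
                \exp\left(-\frac{\|x_i - x_j\|^2}{2\sigma_k^2}\right)
            \]
        where d is opts['latent_space_dim'], sigma_p is opts['pot_pz_std']
        and sigma_k^2 is the width of the Gaussian kernel.

        """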
        n = self.get_batch_size(opts, input_)
        n = tf.cast(n, tf.int32)
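        # Number of distinct pairs i < j. Integer division keeps this an int32
        # tensor, which tf.nn.top_k and the indexing in the median heuristic need.
        half_size = (n * n - n) // 2
        nf = tf.cast(n, tf.float32)
        # Squared pairwise distances ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
        norms = tf.reduce_sum(tf.square(input_), axis=1, keepdims=True)
        dotprods = tf.matmul(input_, input_, transpose_b=True)
        distances = norms + tf.transpose(norms) - 2. * dotprods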
        sigma2_p = opts['pot_pz_std'] ** 2 # var = std ** 2
        # Median heuristic for the sigma^2 of Gaussian kernel
        # sigma2_k = tf.nn.top_k(tf.reshape(distances, [-1]), half_size).values[half_size - 1]
        # Maximal heuristic for the sigma^2 of Gaussian kernel
        # sigma2_k = tf.nn.top_k(tf.reshape(distances, [-1]), 1).values[0]
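        # Fixed kernel width actually used: d * sigma_p^2, i.e. half of
        # E||x - y||^2 = 2 * d * sigma_p^2 for independent x, y ~ N(0, sigma_p^2 I)
        sigma2_k = opts['latent_space_dim'] * sigma2_p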
        if opts['verbose'] == 2:
            sigma2_k = tf.Print(sigma2_k, [tf.nn.top_k(tf.reshape(distances, [-1]), 1).values[0]],
                                'Maximal squared pairwise distance:')
            sigma2_k = tf.Print(sigma2_k, [tf.reduce_mean(distances)],
                                'Average squared pairwise distance:')
            sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:')
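        # Stein kernel u(x_i, x_j) for the score s(x) = -x / sigma_p^2 of a
        # zero-mean isotropic Gaussian, combined with the Gaussian RBF kernel
        res = dotprods / sigma2_p ** 2 \
              - distances * (1. / sigma2_p / sigma2_k + 1. / sigma2_k ** 2) \
              + opts['latent_space_dim'] / sigma2_k
        res = tf.multiply(res, tf.exp(-distances / (2. * sigma2_k)))
        # Zero the diagonal so that only pairs i != j contribute (U-statistic)
        res = tf.multiply(res, 1. - tf.eye(n))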
        stat = tf.reduce_sum(res) / (nf * nf - nf)
        # Alternative normalization by n^2 instead of n^2 - n:
        # stat = tf.reduce_sum(res) / (nf * nf)
        return stat
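The `distances` line relies on the expansion ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>, which needs only one matrix product. The sketch below is a small NumPy check of that identity against direct broadcasting. NumPy stands in for the TF ops here, and `sq_dists_expanded` / `sq_dists_direct` are hypothetical names.

```python
import numpy as np


def sq_dists_expanded(x):
    """||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>, mirroring discriminator_lks_test."""
    norms = np.sum(np.square(x), axis=1, keepdims=True)   # shape (n, 1)
    dotprods = x @ x.T                                     # shape (n, n)
    return norms + norms.T - 2. * dotprods


def sq_dists_direct(x):
    """Reference: explicit differences via broadcasting (n * n * d memory)."""
    diff = x[:, None, :] - x[None, :, :]
    return np.sum(diff ** 2, axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
print(np.max(np.abs(sq_dists_expanded(x) - sq_dists_direct(x))))
```

The expanded form avoids materialising the n x n x d difference tensor. The cost is floating-point cancellation: entries close to zero, notably the diagonal, can come out slightly negative.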
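The closed form in the docstring can be recovered term by term from the generic Stein kernel. This derivation assumes what the code implicitly encodes: the target is p = N(0, sigma_p^2 I_d), with score s_p(x) = -x / sigma_p^2, and the kernel is k(x, y) = exp(-||x - y||^2 / (2 sigma_k^2)). A sketch under those assumptions:

```latex
\begin{aligned}
u_p(x, y) &= s_p(x)^\top s_p(y)\, k(x, y)
           + s_p(x)^\top \nabla_y k(x, y)
           + s_p(y)^\top \nabla_x k(x, y)
           + \operatorname{tr}\!\left(\nabla_x \nabla_y^\top k(x, y)\right), \\
\nabla_x k(x, y) &= -\frac{x - y}{\sigma_k^2}\, k(x, y), \qquad
\nabla_y k(x, y) = \frac{x - y}{\sigma_k^2}\, k(x, y), \\
s_p(x)^\top s_p(y)\, k &= \frac{\langle x, y \rangle}{\sigma_p^4}\, k, \\
s_p(x)^\top \nabla_y k + s_p(y)^\top \nabla_x k
  &= \frac{-\langle x, x - y \rangle + \langle y, x - y \rangle}{\sigma_p^2 \sigma_k^2}\, k
   = -\frac{\|x - y\|^2}{\sigma_p^2 \sigma_k^2}\, k, \\
\operatorname{tr}\!\left(\nabla_x \nabla_y^\top k\right)
  &= \left(\frac{d}{\sigma_k^2} - \frac{\|x - y\|^2}{\sigma_k^4}\right) k .
\end{aligned}
```

Summing these pieces gives the bracketed factor of the docstring multiplied by k(x_i, x_j). Averaging over the n^2 - n ordered pairs with i != j then gives the U-statistic returned as `stat`.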
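To check numerically that the closed form matches the generic Stein kernel u(x, y) = s(x)^T s(y) k + s(x)^T grad_y k + s(y)^T grad_x k + tr(grad_x grad_y k), the statistic can be computed both ways. This is a minimal sketch, not the adagan code. NumPy stands in for TensorFlow, and the helper names (`pairwise_sq_dists`, `ksd_closed_form`, `ksd_generic`) are hypothetical. The score s(x) = -x / sigma_p^2 assumes a N(0, sigma_p^2 I) target, which is what the closed form implicitly encodes.

```python
import numpy as np


def pairwise_sq_dists(x):
    """Squared Euclidean distances ||x_i - x_j||^2 for all rows of x."""
    diff = x[:, None, :] - x[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def ksd_closed_form(x, sigma2_p, sigma2_k):
    """U-statistic with the closed-form Stein kernel used in the TF code.

    Assumes target N(0, sigma2_p * I) and kernel exp(-||x - y||^2 / (2 * sigma2_k)).
    """
    n, d = x.shape
    dotprods = x @ x.T
    distances = pairwise_sq_dists(x)
    res = (dotprods / sigma2_p ** 2
           - distances * (1. / (sigma2_p * sigma2_k) + 1. / sigma2_k ** 2)
           + d / sigma2_k)
    res *= np.exp(-distances / (2. * sigma2_k))
    res *= 1. - np.eye(n)  # keep only pairs i != j
    return res.sum() / (n * n - n)


def ksd_generic(x, sigma2_p, sigma2_k):
    """Same U-statistic, built from score and kernel gradients pair by pair."""
    n, d = x.shape
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            xi, xj = x[i], x[j]
            diff = xi - xj
            r2 = diff @ diff
            k = np.exp(-r2 / (2. * sigma2_k))
            grad_x_k = -diff / sigma2_k * k   # gradient of k w.r.t. its first argument
            grad_y_k = diff / sigma2_k * k    # gradient of k w.r.t. its second argument
            trace_term = (d / sigma2_k - r2 / sigma2_k ** 2) * k
            s_i, s_j = -xi / sigma2_p, -xj / sigma2_p  # Gaussian score (assumption)
            total += s_i @ s_j * k + s_i @ grad_y_k + s_j @ grad_x_k + trace_term
    return total / (n * n - n)


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))
sigma2_p = 1.0
sigma2_k = x.shape[1] * sigma2_p  # same width rule as the TF code: d * sigma_p^2
print(ksd_closed_form(x, sigma2_p, sigma2_k), ksd_generic(x, sigma2_p, sigma2_k))
```

The vectorised closed form is what makes the TF version practical. The double loop only serves as a reference. Shifting the samples away from the target, e.g. `x + 5.0`, should make the statistic clearly positive.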
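The commented-out median heuristic depends on `half_size` being an integer, because `tf.nn.top_k` takes an integer `k`. The flattened n x n distance matrix holds every off-diagonal pair twice plus n zeros on the diagonal. The `half_size`-th largest entry is therefore a median of the distinct pairwise squared distances; when the number of pairs is even, it is the upper of the two middle values. A minimal NumPy sketch of that indexing follows, with `median_sq_dist_topk` as a hypothetical helper and a sort standing in for `tf.nn.top_k`.

```python
import numpy as np


def median_sq_dist_topk(x):
    """NumPy mirror of the commented-out TF median heuristic."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]
    distances = np.sum(diff ** 2, axis=-1)          # (n, n), exactly symmetric, zero diagonal
    half_size = (n * n - n) // 2                    # number of distinct pairs i < j
    flat_desc = np.sort(distances.reshape(-1))[::-1]
    # tf.nn.top_k(flat, half_size).values[half_size - 1] is the
    # half_size-th largest entry of the flattened matrix.
    return flat_desc[half_size - 1]


rng = np.random.default_rng(1)
x = rng.normal(size=(6, 2))
iu = np.triu_indices(6, k=1)                        # distinct pairs i < j
diff = x[:, None, :] - x[None, :, :]
pair_dists = np.sum(diff ** 2, axis=-1)[iu]
print(median_sq_dist_topk(x), np.median(pair_dists))
```

With n = 6 there are 15 distinct pairs, an odd count, so the result coincides with `np.median` of the pairwise squared distances.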