def discriminator_lks_test(self, opts, input_):
    r"""Deterministic discriminator using a Kernel Stein Discrepancy test.

    Implements the quadratic-time (U-statistic) test of
    https://arxiv.org/pdf/1705.07673.pdf with a Gaussian kernel of width
    sigma_k^2. The terms below match the Stein kernel for a zero-mean
    isotropic Gaussian target with standard deviation sigma_p. The statistic
    reads:

    \[
    \frac{1}{n^2 - n}\sum_{i \neq j} \left(
    \frac{\langle x_i, x_j \rangle}{\sigma_p^4}
    + \frac{d}{\sigma_k^2}
    - \|x_i - x_j\|^2\left(\frac{1}{\sigma_p^2\sigma_k^2} + \frac{1}{\sigma_k^4}\right)
    \right)
    \exp\left( - \frac{\|x_i - x_j\|^2}{2\sigma_k^2} \right)
    \]

    Args:
        opts: Options dict. Reads 'pot_pz_std' (sigma_p),
            'latent_space_dim' (d) and 'verbose' (== 2 turns on tf.Print
            diagnostics about the pairwise distances and kernel width).
        input_: Rank-2 batch of samples x_1..x_n, one sample per row.
            It must be float32 to match the float32 `tf.eye` mask below.
            NOTE(review): the row length is presumably equal to
            opts['latent_space_dim']; confirm against callers.

    Returns:
        Scalar float32 tensor holding the U-statistic. For a batch of
        size 1 the normaliser n^2 - n is 0, so the result is NaN.
    """
    n = tf.cast(self.get_batch_size(opts, input_), tf.int32)
    nf = tf.cast(n, tf.float32)
    # float() keeps the scalar arithmetic below in floating point even when
    # the options hold ints (Python 2 `/` would otherwise truncate).
    dim = float(opts['latent_space_dim'])
    sigma2_p = float(opts['pot_pz_std']) ** 2  # var = std ** 2

    # Pairwise squared distances via ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
    norms = tf.reduce_sum(tf.square(input_), axis=1, keep_dims=True)
    dotprods = tf.matmul(input_, input_, transpose_b=True)
    distances = norms + tf.transpose(norms) - 2. * dotprods

    # Kernel width is fixed to d * sigma_p^2. Data-dependent alternatives
    # (median or maximum of the pairwise squared distances) were tried
    # previously; the verbose diagnostics below help compare against them.
    sigma2_k = dim * sigma2_p
    if opts['verbose'] == 2:
        flat_distances = tf.reshape(distances, [-1])
        sigma2_k = tf.Print(sigma2_k, [tf.nn.top_k(flat_distances, 1).values[0]],
                            'Maximal squared pairwise distance:')
        sigma2_k = tf.Print(sigma2_k, [tf.reduce_mean(distances)],
                            'Average squared pairwise distance:')
        sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:')

    # Stein kernel u_p(x_i, x_j): polynomial factor times the Gaussian kernel.
    stein_kernel = (dotprods / sigma2_p ** 2
                    - distances * (1. / sigma2_p / sigma2_k + 1. / sigma2_k ** 2)
                    + dim / sigma2_k)
    stein_kernel = tf.multiply(stein_kernel, tf.exp(-distances / 2. / sigma2_k))

    # Zero the diagonal (i == j) so the sum runs over i != j (U-statistic).
    # The biased V-statistic would keep the diagonal and divide by n^2.
    off_diagonal = tf.multiply(stein_kernel, 1. - tf.eye(n))
    return tf.reduce_sum(off_diagonal) / (nf * nf - nf)
评论列表
文章目录