def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
m = tf.cast(K_XX.get_shape()[0], tf.float32) # Assumes X, Y are same shape
### Get the various sums of kernels that we'll use
# Kts drop the diagonal, but we don't need to compute them explicitly
if const_diagonal is not False:
const_diagonal = tf.cast(const_diagonal, tf.float32)
diag_X = diag_Y = const_diagonal
sum_diag_X = sum_diag_Y = m * const_diagonal
sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
else:
diag_X = tf.diag_part(K_XX)
diag_Y = tf.diag_part(K_YY)
sum_diag_X = tf.reduce_sum(diag_X)
sum_diag_Y = tf.reduce_sum(diag_Y)
sum_diag2_X = sq_sum(diag_X)
sum_diag2_Y = sq_sum(diag_Y)
Kt_XX_sums = tf.reduce_sum(K_XX, 1) - diag_X
Kt_YY_sums = tf.reduce_sum(K_YY, 1) - diag_Y
K_XY_sums_0 = tf.reduce_sum(K_XY, 0)
K_XY_sums_1 = tf.reduce_sum(K_XY, 1)
Kt_XX_sum = tf.reduce_sum(Kt_XX_sums)
Kt_YY_sum = tf.reduce_sum(Kt_YY_sums)
K_XY_sum = tf.reduce_sum(K_XY_sums_0)
Kt_XX_2_sum = sq_sum(K_XX) - sum_diag2_X
Kt_YY_2_sum = sq_sum(K_YY) - sum_diag2_Y
K_XY_2_sum = sq_sum(K_XY)
if biased:
mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
+ (Kt_YY_sum + sum_diag_Y) / (m * m)
- 2 * K_XY_sum / (m * m))
else:
mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * (m-1))
+ (Kt_YY_sum + sum_diag_Y) / (m * (m-1))
- 2 * K_XY_sum / (m * m))
var_est = (
2 / (m**2 * (m-1)**2) * (
2 * sq_sum(Kt_XX_sums) - Kt_XX_2_sum
+ 2 * sq_sum(Kt_YY_sums) - Kt_YY_2_sum)
- (4*m-6) / (m**3 * (m-1)**3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
+ 4*(m-2) / (m**3 * (m-1)**2) * (
sq_sum(K_XY_sums_1) + sq_sum(K_XY_sums_0))
- 4 * (m-3) / (m**3 * (m-1)**2) * K_XY_2_sum
- (8*m - 12) / (m**5 * (m-1)) * K_XY_sum**2
+ 8 / (m**3 * (m-1)) * (
1/m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
- dot(Kt_XX_sums, K_XY_sums_1)
- dot(Kt_YY_sums, K_XY_sums_0))
)
return mmd2, var_est
评论列表
文章目录