def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Estimate the squared Maximum Mean Discrepancy (MMD^2) from kernel matrices.

    Args:
        K_XX: Kernel matrix among the X samples. Its leading dimension is taken
            as the X sample count ``m``. Presumably square (m, m), since the
            unbiased branch removes its diagonal via the trace; confirm with
            callers.
        K_XY: Cross-kernel matrix between X and Y samples. Only its total sum
            is used, normalized by ``m * n``.
        K_YY: Kernel matrix among the Y samples. Its leading dimension is taken
            as the Y sample count ``n``.
        const_diagonal: If neither ``False`` nor ``None``, every diagonal entry
            of ``K_XX`` and ``K_YY`` is assumed to equal this constant. The
            traces are then computed as ``m * const_diagonal`` and
            ``n * const_diagonal`` instead of being read from the matrices.
            Only used when ``biased`` is False.
        biased: If True, return the biased estimate, which averages over every
            kernel entry including the diagonals. Otherwise return the unbiased
            estimate, which excludes the diagonal self-similarity terms.

    Returns:
        A scalar tensor holding the MMD^2 estimate.

    Note:
        The unbiased estimate divides by ``m * (m - 1)`` and ``n * (n - 1)``.
        It is therefore undefined (inf/nan) when either sample count is 1.
    """
    # tf.shape works for both static and dynamic (unknown at graph-build time)
    # leading dimensions; a static-shape lookup fails on an unknown dimension.
    # Cast to the kernel dtype so the divisions below also work for
    # non-float32 kernels (identical to the previous behavior for float32).
    dtype = K_XX.dtype
    m = tf.cast(tf.shape(K_XX)[0], dtype)
    n = tf.cast(tf.shape(K_YY)[0], dtype)

    if biased:
        # Average over all kernel entries, diagonal terms included.
        return (tf.reduce_sum(K_XX) / (m * m)
                + tf.reduce_sum(K_YY) / (n * n)
                - 2 * tf.reduce_sum(K_XY) / (m * n))

    if const_diagonal is not False and const_diagonal is not None:
        # The caller guarantees a constant diagonal (for example k(x, x) == 1
        # for an RBF kernel), so the trace is known without reading the matrix.
        trace_X = m * const_diagonal
        trace_Y = n * const_diagonal
    else:
        # tf.linalg.trace is the TF1/TF2-portable spelling of tf.trace.
        trace_X = tf.linalg.trace(K_XX)
        trace_Y = tf.linalg.trace(K_YY)

    # Unbiased estimate: drop the diagonal self-similarity terms of K_XX and
    # K_YY and normalize by the number of remaining off-diagonal pairs.
    return ((tf.reduce_sum(K_XX) - trace_X) / (m * (m - 1))
            + (tf.reduce_sum(K_YY) - trace_Y) / (n * (n - 1))
            - 2 * tf.reduce_sum(K_XY) / (m * n))
评论列表
文章目录