def correlation_loss(self, opts, input_):
    """Penalize pairwise linear dependence between the columns of `input_`.

    Builds the sample Pearson correlation matrix of the columns and returns
    the mean squared correlation over all distinct column pairs (the strict
    upper triangle of that matrix).

    Keep in mind that Pearson's correlation captures only linear
    dependencies. For a multivariate Gaussian, however, independence is
    equivalent to zero correlation.

    Args:
        opts: Options object, forwarded unchanged to `self.get_batch_size`.
        input_: 2-D tensor whose second dimension is statically known.
            Rows are treated as samples and columns as variables
            (presumably shape (batch, dim) -- confirm against callers).

    Returns:
        Scalar tensor: the sum of squared off-diagonal correlations divided
        by the number of distinct column pairs. With a single column there
        are no pairs, and the result is 0 instead of NaN (0 / 0).
    """
    batch_size = self.get_batch_size(opts, input_)
    dim = int(input_.get_shape()[1])

    # Put variables on the rows: shape (dim, batch).
    transposed = tf.transpose(input_, perm=[1, 0])
    mean = tf.reshape(tf.reduce_mean(transposed, axis=1), [-1, 1])
    centered = transposed - mean  # mean broadcasts along the batch axis

    # Sample covariance with Bessel's correction, shape (dim, dim).
    # NOTE(review): relies on get_batch_size returning something that can
    # divide a float tensor (e.g. a float scalar) -- confirm.
    cov = tf.matmul(centered, centered, transpose_b=True) / (batch_size - 1)

    # Per-variable standard deviations. The small epsilon keeps the sqrt and
    # the division below away from zero for constant columns.
    sigmas = tf.reshape(tf.sqrt(tf.diag_part(cov) + 1e-5), [1, -1])
    # Outer product sigma_i * sigma_j, shape (dim, dim).
    sigma_products = tf.matmul(sigmas, sigmas, transpose_a=True)

    # Pearson's correlation
    corr = cov / sigma_products

    # Keep each unordered pair once: upper triangle with the diagonal
    # (self-correlation, always ~1) zeroed out.
    upper = tf.matrix_set_diag(tf.matrix_band_part(corr, 0, -1), tf.zeros(dim))

    # Number of distinct pairs, dim * (dim - 1) / 2. Clamped to 1 so that a
    # single-column input yields a loss of 0 rather than 0 / 0 = NaN; for
    # dim >= 2 the count is already >= 1, so the value is unchanged.
    num_pairs = max((dim * dim - dim) / 2.0, 1.0)
    return tf.reduce_sum(tf.square(upper)) / num_pairs
评论列表
文章目录