pot.py source code


Project: adagan    Author: tolstikhin
import tensorflow as tf  # TensorFlow 1.x API (tf.diag_part, tf.matrix_band_part, ...)


def correlation_loss(self, opts, input_):
    """
    Independence test based on Pearson's correlation.
    Keep in mind that this captures only linear dependencies.
    However, for a multivariate Gaussian, independence of the
    coordinates is equivalent to zero correlation.
    """

    batch_size = self.get_batch_size(opts, input_)
    dim = int(input_.get_shape()[1])
    # Put features in rows: shape [dim, batch_size]
    transposed = tf.transpose(input_, perm=[1, 0])
    mean = tf.reshape(tf.reduce_mean(transposed, axis=1), [-1, 1])
    centered_transposed = transposed - mean  # Broadcasting mean
    # Unbiased sample covariance matrix, shape [dim, dim]
    cov = tf.matmul(centered_transposed, centered_transposed, transpose_b=True)
    cov = cov / (batch_size - 1)
    # Per-feature standard deviations; the 1e-5 term guards against sqrt(0)
    sigmas = tf.sqrt(tf.diag_part(cov) + 1e-5)
    sigmas = tf.reshape(sigmas, [1, -1])
    # Outer product sigma_i * sigma_j, shape [dim, dim]
    sigmas = tf.matmul(sigmas, sigmas, transpose_a=True)
    # Pearson's correlation
    corr = cov / sigmas
    # Keep the strict upper triangle so each pair (i, j) with i < j is counted once
    triangle = tf.matrix_set_diag(tf.matrix_band_part(corr, 0, -1), tf.zeros(dim))
    # Mean squared correlation over the dim * (dim - 1) / 2 distinct pairs
    loss = tf.reduce_sum(tf.square(triangle)) / ((dim * dim - dim) / 2.0)
    return loss
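For a quick numerical check outside a TensorFlow 1.x graph, the following is a minimal sketch that re-expresses the same loss in NumPy. This is a swap of framework for illustration, not code from the adagan project; the name `correlation_loss_np` is hypothetical, and the input is assumed to be a plain array of shape [batch_size, dim]. It computes the unbiased covariance, divides by the outer product of the stabilized standard deviations, and averages the squared entries of the strict upper triangle, i.e. over the dim * (dim - 1) / 2 distinct feature pairs.

```python
import numpy as np


def correlation_loss_np(x: np.ndarray) -> float:
    """Hypothetical NumPy re-implementation for illustration (not part of adagan).

    x has shape [batch_size, dim]. Returns the mean of the squared Pearson
    correlations over the dim * (dim - 1) / 2 distinct feature pairs.
    """
    batch_size, dim = x.shape
    centered = x - x.mean(axis=0, keepdims=True)
    # Unbiased sample covariance, shape [dim, dim]
    cov = centered.T @ centered / (batch_size - 1)
    # Same 1e-5 stabilizer as the TensorFlow version
    sigmas = np.sqrt(np.diag(cov) + 1e-5)
    corr = cov / np.outer(sigmas, sigmas)
    # Strict upper triangle: k=1 excludes the diagonal
    upper = np.triu(corr, k=1)
    return float(np.sum(upper ** 2) / (dim * (dim - 1) / 2.0))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 10_000

    # Independent standard Gaussians: loss should be close to 0
    indep = rng.standard_normal((n, 3))
    print("independent:", correlation_loss_np(indep))

    # Perfectly linearly dependent columns: loss should be close to 1
    a = rng.standard_normal(n)
    linear = np.stack([a, 2.0 * a + 1.0], axis=1)
    print("linear:", correlation_loss_np(linear))

    # x**2 is fully determined by x, yet their Pearson correlation is near 0
    x = rng.standard_normal(n)
    nonlinear = np.stack([x, x ** 2], axis=1)
    print("nonlinear:", correlation_loss_np(nonlinear))
```

The third case illustrates the caveat in the docstring: x and x**2 are completely dependent, yet the loss stays small because Pearson's correlation sees only linear relationships. Zero correlation implies independence only when the variables are jointly Gaussian.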