discriminator.py source code

python

Project: tanda  Author: HazyResearch
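For reference, the value returned by `_tr_term` can be written out as follows. The notation is introduced here for illustration and is read directly off the code, not taken from the cited paper: $a^{(b)}_i \in \mathbb{R}^K$ denotes the logits of the $i$-th transformed pass of data point $b$, for a batch of $B$ points and $N_p$ passes.

```latex
\begin{aligned}
D^{(b)}_{ij} &= \lVert a^{(b)}_i - a^{(b)}_j \rVert_2^2
  = \lVert a^{(b)}_i \rVert_2^2 - 2\, {a^{(b)}_i}^{\top} a^{(b)}_j + \lVert a^{(b)}_j \rVert_2^2 \\
\mathcal{L}_{\mathrm{TR}} &= \frac{1}{B N_p} \sum_{b=1}^{B} \sum_{i=1}^{N_p} \sum_{j=1}^{i} D^{(b)}_{ij}
  = \frac{1}{B N_p} \sum_{b=1}^{B} \sum_{1 \le j < i \le N_p} \lVert a^{(b)}_i - a^{(b)}_j \rVert_2^2
\end{aligned}
```

The three terms of the expansion of $D^{(b)}_{ij}$ correspond to `R`, `-2 * S` and `R_t` in the code, and the second equality for $\mathcal{L}_{\mathrm{TR}}$ uses $D^{(b)}_{ii} = 0$.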
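import tensorflow as tf


# Method excerpted from discriminator.py (hence the `self` argument).
def _tr_term(self, logits_arr, Np):
    """Compute the TR regularization term.

    logits_arr is a TensorArray holding Np logit tensors, each of shape
    [B, K] (K = number of classes), one per random transformation of the
    same batch of B data points. For each data point, the term sums the
    squared Euclidean distances between the logits of every distinct pair
    of transformed passes; the total is then divided by B * Np.

    See https://papers.nips.cc/paper/6333-regularization-with-stochastic-transformations-and-perturbations-for-deep-semi-supervised-learning.pdf
    """
    # Stack to [Np, B, K], then transpose to [B, Np, K]
    A = tf.transpose(logits_arr.stack(), [1, 0, 2])

    # R[b, i, 0] = ||a_i||_2^2, squared norm of each logit vector; shape [B, Np, 1]
    R = tf.reshape(tf.reduce_sum(A * A, 2), [-1, Np, 1])
    # R_t[b, 0, j] = ||a_j||_2^2; shape [B, 1, Np]
    R_t = tf.transpose(R, [0, 2, 1])
    # S[b, i, j] = a_i . a_j (batched Gram matrix); shape [B, Np, Np]
    S = tf.matmul(A, tf.transpose(A, [0, 2, 1]))
    # Pairwise squared distances D[b, i, j] = ||a_i - a_j||_2^2 (by broadcasting)
    D = R - 2 * S + R_t

    # Keep the lower triangle (j <= i) so each pair is counted once; the
    # diagonal is (numerically close to) zero. tf.linalg.band_part is the
    # current name of the TF1 op tf.matrix_band_part.
    D_lt = tf.linalg.band_part(D, -1, 0)
    # Sum over j, then average over the B * Np rows,
    # i.e. (sum over distinct pairs) / (B * Np)
    return tf.reduce_mean(tf.reduce_sum(D_lt, axis=2))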
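To check the Gram-matrix identity and the normalization without TensorFlow, here is a minimal pure-Python sketch. `tr_term_gram` and `tr_term_pairs` are hypothetical helper names, not part of tanda; both take nested lists in the `[B, Np, K]` layout that `_tr_term` builds with `tf.transpose`. The first mirrors `_tr_term` step by step (norms, dot products, lower triangle, divide by `B * Np`); the second sums squared distances over distinct pairs directly.

```python
import itertools
import math
import random


def tr_term_gram(A):
    """Hypothetical helper: mirror _tr_term on nested lists A[B][Np][K]."""
    B, Np = len(A), len(A[0])
    total = 0.0
    for a in A:  # a: the Np logit vectors of one data point
        sq = [sum(x * x for x in v) for v in a]  # like R: ||a_i||^2
        for i in range(Np):
            for j in range(i + 1):  # lower triangle incl. diagonal, like band_part(D, -1, 0)
                dot = sum(x * y for x, y in zip(a[i], a[j]))  # like S[i][j]
                total += sq[i] - 2 * dot + sq[j]  # D[i][j] = ||a_i - a_j||^2
    # reduce_sum over j, then reduce_mean over the B * Np (b, i) rows
    return total / (B * Np)


def tr_term_pairs(A):
    """Hypothetical helper: sum ||a_i - a_j||^2 over distinct pairs, divide by B * Np."""
    B, Np = len(A), len(A[0])
    total = 0.0
    for a in A:
        for u, v in itertools.combinations(a, 2):
            total += sum((x - y) ** 2 for x, y in zip(u, v))
    return total / (B * Np)


# One data point, two passes, K = 1: the only pair has distance (3 - 0)^2 = 9.
print(tr_term_gram([[[0.0], [3.0]]]))  # 4.5

# Random [B=3, Np=4, K=5] logits: both versions agree up to rounding.
random.seed(0)
A = [[[random.uniform(-1, 1) for _ in range(5)] for _ in range(4)] for _ in range(3)]
print(math.isclose(tr_term_gram(A), tr_term_pairs(A), rel_tol=1e-9))  # True
```

Because the total is divided by `B * Np` rather than by the `B * Np * (Np - 1) / 2` pairs, the result equals `(Np - 1) / 2` times the per-pair mean; the two coincide only when `Np = 3`.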