def _tr_term(self, logits_arr, Np):
    """Return the TR regularization term for a set of repeated passes.

    For every example, the squared Euclidean distance between the logits
    of each pair of passes is accumulated, and the accumulated values are
    then averaged into a single scalar.

    Reference:
    https://papers.nips.cc/paper/6333-regularization-with-stochastic-transformations-and-perturbations-for-deep-semi-supervised-learning.pdf

    Args:
        logits_arr: Object whose ``stack()`` yields the logits of all
            passes as one tensor. Presumably a ``tf.TensorArray`` of Np
            tensors shaped [B, K] (B = batch size, K = number of
            classes) -- TODO confirm against the caller.
        Np: Number of passes per example; used as the size of the pass
            axis when reshaping the squared norms.

    Returns:
        Scalar tensor. Per example, the squared distances of all pass
        pairs are summed; the mean is then taken over the [B, Np] row
        sums, so the total is divided by B * Np rather than by the
        number of distinct pairs.
    """
    # Move the batch axis to the front; assuming stack() is [Np, B, K],
    # this gives [B, Np, K].
    logits = tf.transpose(logits_arr.stack(), [1, 0, 2])

    # Per-pass squared norms ||a_i||^2 (element-wise product, not matmul),
    # laid out as a column [B, Np, 1] and as a row [B, 1, Np].
    sq_norms = tf.reduce_sum(logits * logits, 2)
    sq_col = tf.reshape(sq_norms, [-1, Np, 1])
    sq_row = tf.transpose(sq_col, [0, 2, 1])

    # Inner products <a_i, a_j> between passes of one example: [B, Np, Np].
    logits_t = tf.transpose(logits, [0, 2, 1])
    inner = tf.matmul(logits, logits_t)

    # ||a_i - a_j||^2 = ||a_i||^2 - 2 <a_i, a_j> + ||a_j||^2
    pair_dist = sq_col - 2 * inner + sq_row

    # Keep only the lower triangle so each unordered pair is counted once.
    # The diagonal is kept as well, but it is zero up to rounding error.
    lower = tf.matrix_band_part(pair_dist, -1, 0)

    # Sum each row of the triangle, then average over batch and pass axes.
    row_sums = tf.reduce_sum(lower, axis=2)
    return tf.reduce_mean(row_sums)
评论列表
文章目录