def covar_loss(self, top_states):
    """Penalize correlation between the hidden dimensions of `top_states`.

    Computes the covariance matrix of the hidden dimensions over the rows of
    `top_states` kept by `self.tokens_to_keep3D`. Returns `tf.nn.l2_loss`
    (half the sum of squares) of that matrix's off-diagonal entries.

    Args:
      top_states: rank-3 tensor (presumably batch x tokens x hidden) or
        rank-4 tensor (presumably batch x tokens x tokens x hidden, one
        vector per token pair). The last dimension must be statically known.

    Returns:
      A scalar float tensor.

    Raises:
      ValueError: if `top_states` is neither rank 3 nor rank 4.
    """
    shape = top_states.get_shape().as_list()
    n_dims = len(shape)
    hidden_size = shape[-1]

    # NOTE(review): assumes tokens_to_keep3D is a 0/1 float mask shaped like
    # (batch, tokens, 1) that broadcasts against top_states -- confirm.
    keep = self.tokens_to_keep3D
    if n_dims == 3:
        mask = keep
    elif n_dims == 4:
        # A token pair (i, j) is kept only when both tokens are kept.
        mask = tf.expand_dims(keep, 1) * tf.expand_dims(keep, 2)
    else:
        raise ValueError(
            'covar_loss expects a rank-3 or rank-4 tensor, got rank %d' % n_dims)

    # Broadcast the mask to the full state shape so it flattens row-for-row
    # alongside the states. The mask is applied exactly once here.
    flat_mask = tf.reshape(mask * tf.ones_like(top_states), [-1, hidden_size])
    flat_states = tf.reshape(top_states, [-1, hidden_size]) * flat_mask

    # Number of kept rows, counted from the mask itself. This is float by
    # construction, and for rank 4 it is the true number of token pairs even
    # when sentence lengths differ. The maximum guards an all-padding batch.
    n_rows = tf.maximum(tf.reduce_sum(flat_mask[:, :1]), 1.)

    means = tf.reduce_sum(flat_states, 0, keep_dims=True) / n_rows
    # Re-mask after centering: otherwise padding rows become -means and add
    # spurious mean outer products to the covariance.
    centered_states = (flat_states - means) * flat_mask
    covar_mat = tf.matmul(centered_states, centered_states, transpose_a=True) / n_rows
    # Zero the diagonal so only cross-dimension covariance is penalized.
    off_diag_mask = 1. - tf.diag(tf.ones([hidden_size]))
    return tf.nn.l2_loss(covar_mat * off_diag_mask)
#=============================================================
# [scraped web-page residue] 评论列表 ("Comment list")
# [scraped web-page residue] 文章目录 ("Table of contents")