nn.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目：Sing_Par 作者：wanghm92 项目源码 文件源码
def covar_loss(self, top_states):
    """Return an L2 penalty on the off-diagonal covariance of `top_states`.

    The states are masked with `self.tokens_to_keep3D`, flattened to
    (rows, hidden_size), and centered. Their hidden-by-hidden covariance
    matrix is then computed. The result is `tf.nn.l2_loss` of that matrix's
    off-diagonal entries.

    Args:
      top_states: a rank-3 or rank-4 tensor whose last dimension is the
        hidden size. Presumably (batch, n, hidden) for token states and
        (batch, n, n, hidden) for token-pair states -- TODO confirm.

    Returns:
      A scalar loss tensor.

    Raises:
      ValueError: if `top_states` is neither rank 3 nor rank 4.
    """

    shape = top_states.get_shape().as_list()
    n_dims = len(shape)
    hidden_size = shape[-1]

    if n_dims == 3:
      # NOTE(review): assumes tokens_to_keep3D is a (batch, n, 1) 0/1 mask -- confirm.
      mask = self.tokens_to_keep3D
      # Cast so the float divisions below also work if n_tokens is an int tensor.
      n_tokens = tf.to_float(self.n_tokens)
    elif n_dims == 4:
      # Keep the (i, j) state only when both token i and token j are kept.
      mask = tf.expand_dims(self.tokens_to_keep3D, 1) * tf.expand_dims(self.tokens_to_keep3D, 2)
      # Normalize by the number of (i, j) rows that survive masking.
      # self.n_tokens**2 over-counts these when self.n_tokens is a
      # batch-wide total (presumably so -- confirm).
      n_tokens = tf.reduce_sum(mask)
    else:
      raise ValueError('covar_loss expects a rank-3 or rank-4 tensor, got rank %d' % n_dims)

    # Apply the mask exactly once. Multiplying again by the rank-3 mask
    # mis-broadcasts against rank-4 states.
    top_states = tf.reshape(top_states * mask, [-1, hidden_size])
    # One mask entry per flattened row (relies on the mask's last dim being 1).
    row_mask = tf.reshape(mask, [-1, 1])
    means = tf.reduce_sum(top_states, 0, keep_dims=True) / n_tokens
    # Re-mask after centering. Otherwise masked rows become -means and
    # leak into the covariance estimate.
    centered_states = (top_states - means) * row_mask
    covar_mat = tf.matmul(centered_states, centered_states, transpose_a=True) / n_tokens
    # Zero the diagonal so only cross-unit covariances are penalized.
    off_diag_mask = 1 - tf.diag(tf.ones([hidden_size]))
    return tf.nn.l2_loss(covar_mat * off_diag_mask)

  #=============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号