gan_losses.py source code

python

Project: tefla    Author: openAGI
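import tensorflow as tf


def _numerically_stable_global_norm(tensor_list):
    """Compute the global norm of a list of Tensors, with improved stability.

    The plain global norm can overflow in its intermediate step, which sums the
    squares of every element. To avoid this, every tensor is first divided by
    the largest absolute value across all of them (cheap to compute), and the
    resulting norm is multiplied back by that same value.

    Args:
      tensor_list: A list of tensors. Entries may be `None`; they are ignored.

    Returns:
      A scalar tensor with the global norm, or `0.0` if every entry is `None`.
    """
    # Skip `None` entries (for example, variables that received no gradient).
    tensors = [x for x in tensor_list if x is not None]
    if not tensors:
        return 0.0

    # Largest absolute value over all elements of all tensors.
    list_max = tf.reduce_max([tf.reduce_max(tf.abs(x)) for x in tensors])
    # Guard against 0 / 0 = NaN when every element is zero.
    safe_max = tf.where(list_max > 0, list_max, tf.ones_like(list_max))
    # For any m > 0, ||x|| = m * ||x / m||, so the rescaling does not change the result.
    return safe_max * tf.global_norm([x / safe_max for x in tensors])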
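The effect of the rescaling can be reproduced without TensorFlow. The sketch below is illustrative only and is not part of tefla. It uses NumPy float32 arrays, assumed here to stand in for the tensors, and two hypothetical helpers, `naive_global_norm` and `stable_global_norm`. Squaring values near 1e20 exceeds the float32 maximum (about 3.4e38), so the direct computation returns `inf`. Dividing by the global max first keeps every square at or below 1. The identity ||x|| = m · ||x / m|| for m > 0 means the final multiplication recovers the true norm.

```python
import numpy as np


def naive_global_norm(arrays):
    """Hypothetical helper: sqrt of the sum of all squared elements (can overflow)."""
    squares = [np.sum(np.square(a)) for a in arrays]
    return np.sqrt(np.sum(np.array(squares, dtype=np.float32)))


def stable_global_norm(arrays):
    """Hypothetical helper mirroring the max-rescaling idea above, in NumPy."""
    # Ignore None entries, as the TensorFlow version does.
    arrays = [a for a in arrays if a is not None]
    if not arrays:
        return 0.0
    # Largest absolute value across all arrays.
    list_max = max(np.max(np.abs(a)) for a in arrays)
    if list_max == 0:
        return 0.0  # every element is zero, so the norm is zero
    # After division every |element| <= 1, so squaring cannot overflow.
    return list_max * naive_global_norm([a / list_max for a in arrays])


big = [np.array([3e20], dtype=np.float32), np.array([4e20], dtype=np.float32)]
with np.errstate(over="ignore"):
    print(naive_global_norm(big))  # overflows to inf in float32
print(stable_global_norm(big))     # approximately 5e20
```

In this example the max is 4e20, so the scaled elements are 0.75 and 1.0. Their norm is 1.25, and 1.25 × 4e20 = 5e20.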