util.py source code

python

Project: tefla  Author: openAGI
import tensorflow as tf  # TF 1.x API (tf.rsqrt, tf.diag); on TF 2.x, `import tensorflow.compat.v1 as tf`


def BatchClipByL2norm(t, upper_bound, name=None):
    """Clip a batch tensor by L2 norm, one example at a time.

    Shrink each dimension-0 slice of the tensor (for a matrix, each row) so
    that its L2 norm is at most upper_bound. Each slice is clipped separately
    because it corresponds to one example in the batch; slices already within
    the bound are left unchanged.

    Args:
      t: the input tensor; dimension 0 is the batch dimension.
      upper_bound: the upper bound of the L2 norm; a positive Python number.
      name: optional name.

    Returns:
      the clipped tensor, with the same shape as t.
    """

    assert upper_bound > 0
    with tf.name_scope(values=[t, upper_bound], name=name,
                       default_name="batch_clip_by_l2norm") as name:
        saved_shape = tf.shape(t)
        batch_size = tf.slice(saved_shape, [0], [1])
        # Flatten every example into one row: shape [batch_size, -1].
        t2 = tf.reshape(t, tf.concat(axis=0, values=[batch_size, [-1]]))
        upper_bound_inv = tf.fill(batch_size, tf.constant(1.0 / upper_bound))
        # Add a small number to avoid divide by 0
        l2norm_inv = tf.rsqrt(tf.reduce_sum(t2 * t2, [1]) + 0.000001)
        # Per-row factor: min(1, upper_bound / ||row||).
        scale = tf.minimum(l2norm_inv, upper_bound_inv) * upper_bound
        # diag(scale) @ t2 multiplies row i of t2 by scale[i].
        clipped_t = tf.matmul(tf.diag(scale), t2)
        clipped_t = tf.reshape(clipped_t, saved_shape, name=name)
        return clipped_t
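To check what the graph above computes, here is a minimal NumPy sketch of the same per-row rule, scale_i = min(1, upper_bound / ||row_i||). It assumes NumPy is available. The names `batch_clip_by_l2norm_np` and `eps` are hypothetical and not part of tefla. The sketch also swaps `tf.matmul(tf.diag(scale), t2)` for a broadcast multiply, which scales the same rows without building a batch-by-batch matrix.

```python
import numpy as np


def batch_clip_by_l2norm_np(t, upper_bound, eps=1e-6):
    """Hypothetical NumPy mirror of BatchClipByL2norm (not part of tefla).

    Each dimension-0 slice of `t` is flattened into one row. Rows whose L2 norm
    exceeds `upper_bound` are rescaled to norm `upper_bound`; the rest are
    returned unchanged.
    """
    assert upper_bound > 0
    t = np.asarray(t, dtype=np.float64)
    # Flatten every example into one row: shape [batch, -1].
    rows = t.reshape(t.shape[0], -1)
    # 1 / ||row||, with eps added as in tf.rsqrt(... + 0.000001).
    l2norm_inv = 1.0 / np.sqrt(np.sum(rows * rows, axis=1) + eps)
    # min(1/||row||, 1/upper_bound) * upper_bound == min(upper_bound/||row||, 1)
    scale = np.minimum(l2norm_inv, 1.0 / upper_bound) * upper_bound
    # Broadcasting scale over columns equals diag(scale) @ rows.
    clipped = rows * scale[:, np.newaxis]
    return clipped.reshape(t.shape)


x = np.array([[3.0, 4.0], [0.3, 0.4]])
print(batch_clip_by_l2norm_np(x, 1.0))
```

With `upper_bound=1.0`, the first row (norm 5) is shrunk to roughly `[0.6, 0.8]`, and the second row (norm 0.5) is already inside the bound, so it comes back unchanged.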