def l2_batch_normalize(x, epsilon=1e-12, scope=None):
    """
    Scale each flattened example of a batch by the inverse of its L2 norm.

    The input is flattened with ``tf.contrib.layers.flatten``; each row is
    divided by its largest absolute entry, multiplied by the reciprocal
    square root of its sum of squares, and the result is reshaped back to
    the dynamic shape of the input.

    :param x: the input placeholder
    :param epsilon: stabilizes division; ``epsilon`` is added to the
                    max-abs divisor and ``sqrt(epsilon)`` to the sum of
                    squares before the reciprocal square root
    :param scope: optional name handed to ``tf.name_scope``; when None the
                  default name "l2_batch_normalize" is used
    :return: the batch of l2 normalized vectors, shaped like ``x``
    """
    with tf.name_scope(scope, "l2_batch_normalize") as op_scope:
        original_shape = tf.shape(x)
        flat = tf.contrib.layers.flatten(x)

        # Rescale every row by its largest magnitude before squaring
        # (presumably to keep the squared sum in a safe numeric range).
        row_peak = tf.reduce_max(tf.abs(flat), 1, keep_dims=True)
        flat = flat / (epsilon + row_peak)

        row_energy = tf.reduce_sum(tf.square(flat), 1, keep_dims=True)
        inv_norm = tf.rsqrt(np.sqrt(epsilon) + row_energy)
        normalized = tf.multiply(flat, inv_norm)

        # The scope string produced above also names the final reshape op.
        return tf.reshape(normalized, original_shape, name=op_scope)
评论列表
文章目录