def renorm_rms(X, axis=1, target_rms=1.0, name="RescaleRMS"):
    """Rescale `X` so the root-mean-square along `axis` equals `target_rms`.

    Each slice along `axis` is mapped to ``target_rms * X / rms(X)`` with
    ``rms(X) = sqrt(sum(X ** 2, axis) / D)``, where ``D`` is the size of
    `axis`. Equivalently, `X` is divided by
    ``scale = sqrt(x^T x / (D * target_rms^2))``.

    Slices that are entirely zero are returned as zeros (no division by
    zero), and the gradient through them stays finite instead of NaN.

    Parameters
    ----------
    X : Tensor
        Input tensor. Presumably a floating-point dtype — `D` is cast to
        ``X.dtype.base_dtype`` and passed through ``tf.sqrt``.
    axis : int
        The features dimension to normalize over. Defaults to ``1``.
    target_rms : float
        Desired RMS of every normalized slice.
    name : str
        Name of the ``tf.variable_scope`` that wraps the created ops.

    Returns
    -------
    Tensor
        Same shape as `X`; each slice along `axis` has RMS `target_rms`
        (or is all zeros if the input slice was all zeros).
    """
    with tf.variable_scope(name):
        dtype = X.dtype.base_dtype
        # Number of features along `axis`, as a tensor of X's dtype.
        D = tf.cast(tf.shape(X)[axis], dtype)
        sum_sq = tf.reduce_sum(X ** 2, axis=axis, keep_dims=True)
        # Guard *before* the sqrt: d/ds sqrt(s) is infinite at s == 0, so
        # masking after the sqrt with tf.where still back-propagates
        # 0 * inf = NaN into all-zero slices. Replacing a zero sum by 1
        # keeps every op finite; such slices still come out as 0 because
        # X itself is 0 there.
        sum_sq = tf.where(tf.equal(sum_sq, 0.),
                          x=tf.ones_like(sum_sq, dtype=dtype),
                          y=sum_sq)
        X_rms = tf.sqrt(sum_sq / D)
        return target_rms * X / X_rms
# ===========================================================================
# RNN and loop
# ===========================================================================
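# NOTE(review): stray web-page navigation text ("Comment list",
# "Table of contents") was captured here when this code was copied.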