tensor.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def renorm_rms(X, axis=1, target_rms=1.0, name="RescaleRMS"):
  """ Scales the data such that RMS of the features dimension is 1.0
  scale = sqrt(x^t x / (D * target_rms^2)).

  NOTE
  ----
  by defaults, assume the features dimension is `1`
  """
  with tf.variable_scope(name):
    D = tf.sqrt(tf.cast(tf.shape(X)[axis], X.dtype.base_dtype))
    l2norm = tf.sqrt(tf.reduce_sum(X ** 2, axis=axis, keep_dims=True))
    X_rms = l2norm / D
    X_rms = tf.where(tf.equal(X_rms, 0.),
                     x=tf.ones_like(X_rms, dtype=X_rms.dtype.base_dtype),
                     y=X_rms)
    return target_rms * X / X_rms


# ===========================================================================
# RNN and loop
# ===========================================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号